function demo_emgmm(action,hfigure,varargin)
% DEMO_EMGMM Demo on Expectation-Maximization (EM) algorithm.
%
% Synopsis:
%  demo_emgmm
%
% Description:
%  This demo shows the Expectation-Maximization (EM) algorithm
%  [Schles68][DLR77] for the Gaussian mixture model (GMM). The EM
%  fits the GMM to i.i.d. sample data (in this case only 2D)
%  such that the likelihood is maximized.
%
%  The found model is described by ellipses (shapes of the
%  covariances) and crosses (mean value vectors). The value
%  of the optimized log-likelihood function for the current estimate
%  is displayed in the bottom part.
%
% Control:
%  Covariance  - Determines type of the covariance matrix:
%                Diagonal (independent features),
%                Full (correlated features).
%  Components  - Number of components (Gaussians) in the mixture.
%
%  Iterations  - Number of iterations in one step.
%  Random init - The initial model is randomly generated and/or
%                the first n training samples are taken as the
%                mean vectors.
%
%  FIG2EPS      - Export screen to the PostScript file.
%  Save model   - Save current model to file.
%  Load data    - Load input point sets from file.
%  Create data  - Invoke program for creating point sets.
%  Reset        - Set the tested algorithm to the initial state.
%  Play         - Run the tested algorithm.
%  Stop         - Stop the running algorithm.
%  Step         - Perform only one step.
%  Info         - Info box.
%  Close        - Close the program.
%
% See also EMGMM.
%
% About: Statistical Pattern Recognition Toolbox
% (C) 1999-2003, Written by Vojtech Franc and Vaclav Hlavac
% <a href="http://www.cvut.cz">Czech Technical University Prague</a>
% <a href="http://www.feld.cvut.cz">Faculty of Electrical Engineering</a>
% <a href="http://cmp.felk.cvut.cz">Center for Machine Perception</a>

% Modifications:
% 19-sep-2003, VF
% 11-june-2001, V.Franc, comments added.
% 27.02.00 V. Franc
% 5. 4.00 V. Franc
% 23.06.00 V. Hlavac Comments polished. Message when no data loaded.
%          Export of the solution to global variables.
% 27-mar-2001, V.Franc, Graph of log-likelihood function added
%
% Implementation notes:
%  ACTION selects what to do; the uicontrol callbacks call this function
%  again with one of: 'initialize' (default), 'savemodel', 'loadmodel',
%  'play', 'step', 'getfile', 'loadsets', 'reset', 'creatdata',
%  'created', 'info'. HFIGURE is the demo figure; its 'UserData' holds the
%  struct of widget handles and the current EM estimate. The loaded data
%  set is stored in the 'UserData' of the upper axes, the loaded file
%  description in the 'UserData' of the 'Load data' button.

AXIST_ADD=10;   % the logL graph x-range is extended by iter*AXIST_ADD steps
AXISY_ADD=5;    % headroom added above a new maximum of the logL graph
BORDER=0.25;    % relative margin around the data in the upper axes

if nargin < 1,
  action = 'initialize';
end

switch lower(action)

case 'initialize'
  % ---- main figure -----------------------------------------------------
  left=0.2;
  width=0.6;
  bottom=0.1;
  height=0.8;
  hfigure=figure('Name','EM algorithm', ...
     'Visible','off',...
     'NumberTitle','off', ...
     'Units','normalized', ...
     'Position',[left bottom width height],...
     'tag','demo_emgmm',...
     'doublebuffer','on',...
     'backingstore','off');

  % ---- upper axes: data points and the fitted mixture ------------------
  left=0.1;
  width=0.65;
  bottom=0.45;
  height=0.5;
  haxes1=axes(...
     'Units','normalized', ...
     'NextPlot','add',...
     'UserData',[],...
     'Position',[left bottom width height]);
  xlabel('feature x\_1');
  ylabel('feature x\_2');
  htitle1=title('No data loaded',...
     'VerticalAlignment','bottom',...
     'Parent',haxes1,...
     'HorizontalAlignment','left',...
     'Units','normalized',...
     'Position',[0 1 0]);

  % ---- lower axes: log-likelihood versus step number -------------------
  left=0.1;
  width=0.65;
  bottom=0.1;
  height=0.25;
  haxes2=axes(...
     'Units','normalized', ...
     'NextPlot','add',...
     'Position',[left bottom width height]);
  ylabel('logL(t)');
  htitle2=title('Log-likelihood function',...
     'Parent',haxes2,...
     'VerticalAlignment','bottom',...
     'Units','normalized',...
     'HorizontalAlignment','left',...
     'Position',[0 1 0]);
  htxsteps=xlabel('step number t=0');

  % ---- buttons ---------------------------------------------------------
  width=0.1;
  left=0.75-width;
  bottom=0.95;
  height=0.04;
  hbtfig2eps = uicontrol(...
     'Units','Normalized', ...
     'Callback','fig2eps(gcf)',...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','FIG2EPS');

  left=0.8;
  bottom=0.05;
  height=0.044;
  width=0.15;
  hbtclose = uicontrol(...
     'Units','Normalized', ...
     'Callback','close(gcf)',...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Close');

  bottom=bottom+1.5*height;
  hbtinfo = uicontrol(...
     'Units','Normalized', ...
     'Callback','demo_emgmm(''info'',gcf)',...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Info');

  bottom=bottom+1.5*height;
  hbtstep = uicontrol(...
     'Units','Normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Step', ...
     'Interruptible','off',...
     'Callback','demo_emgmm(''step'',gcf)');

  % The Stop button only raises a flag (its UserData) polled by 'play'.
  bottom=bottom+height;
  hbtstop = uicontrol(...
     'Units','Normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Stop', ...
     'Callback','set(gcbo,''UserData'',1)',...
     'Enable','off');

  bottom=bottom+height;
  hbtplay = uicontrol(...
     'Units','Normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Play', ...
     'Callback','demo_emgmm(''play'',gcf)');

  bottom=bottom+height;
  hbtreset = uicontrol(...
     'Units','Normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Reset', ...
     'Callback','demo_emgmm(''reset'',gcf)');

  bottom=bottom+1.5*height;
  hbtcreat = uicontrol(...
     'Units','Normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Create data', ...
     'Callback','demo_emgmm(''creatdata'',gcf)');

  bottom=bottom+1*height;
  hbtload = uicontrol(...
     'Units','Normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Load data', ...
     'Callback','demo_emgmm(''getfile'',gcf)');

  bottom=bottom+1.5*height;
  hbtSaveModel = uicontrol(...
     'Units','Normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Save model', ...
     'Callback','demo_emgmm(''savemodel'',gcf)');

  % ---- model settings --------------------------------------------------
  bottom=0.95-height;
  htxfeatures=uicontrol( ...
     'Style','text', ...
     'Units','normalized', ...
     'Position',[left bottom width height], ...
     'String','Covariance');

  % A cellstr avoids the char-matrix concatenation error that rows of
  % different lengths ('Diagonal ' vs 'Full ') would cause.
  % Popup value 1 = Diagonal, 2 = Full (used by em_update).
  bottom=bottom-height;
  hpufeatures=uicontrol( ...
     'Style','popup', ...
     'Units','normalized', ...
     'Position',[left bottom width height], ...
     'String',{'Diagonal','Full'});

  bottom=bottom-1.3*height;
  htxclasses=uicontrol( ...
     'Style','text', ...
     'Units','normalized', ...
     'Position',[left bottom width 0.9*height], ...
     'String','Components');

  bottom=bottom-height;
  hedclasses = uicontrol(...
     'Units','normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'Style','edit',...
     'String','2');

  bottom=bottom-1.3*height;
  htxiter=uicontrol( ...
     'Style','text', ...
     'Units','normalized', ...
     'Position',[left bottom width 0.9*height], ...
     'String','Iterations');

  bottom=bottom-height;
  hediter = uicontrol(...
     'Units','normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'Style','edit',...
     'String','1');

  bottom=bottom-height*1.3;
  hxbrandom = uicontrol(...
     'Style','checkbox', ...
     'Units','normalized', ...
     'ListboxTop',0, ...
     'Position',[left bottom width height], ...
     'String','Random init');

  % ---- state struct stored in the figure -------------------------------
  handlers=struct(...
     'ellipse',struct('handler',-1,'mi',[],'sigma',[],'t',0,'Pk',[],...
        'solution',0),...
     'center',[],...
     'graph1',struct('handler',-1,'loglik',[],'axist',0,'time',[]),...
     'title1',htitle1,...
     'title2',htitle2,...
     'btSaveModel',hbtSaveModel,...
     'btstep',hbtstep,...
     'btstop',hbtstop,...
     'btclose',hbtclose,...
     'btplay',hbtplay,...
     'btreset',hbtreset,...
     'btinfo',hbtinfo,...
     'btload',hbtload,...
     'btcreat',hbtcreat,...
     'txsteps',htxsteps,...
     'txclasses',htxclasses,...
     'txiter',htxiter,...
     'txfeatures',htxfeatures,...
     'pufeatures',hpufeatures,...
     'editer',hediter,...
     'xbrandom',hxbrandom,...
     'axes1',haxes1,...
     'axes2',haxes2,...
     'edclasses',hedclasses);

  set(hfigure,'UserData',handlers)
  demo_emgmm('reset',hfigure);
  set(hfigure,'Visible','on');
  drawnow;

case 'savemodel'
  % Saves Mean (2xK), Cov (2x2xK), Prior and fun='pdfgmm' to a MAT-file.
  h=get(hfigure,'UserData');
  if h.ellipse.t == 0,
    errordlg('No model has found yet.','No model to save','modal');
    return;
  end
  [name,path]=uiputfile('*.mat','Save model');
  if name ~= 0,
    Mean = h.ellipse.mi;
    Cov = reshape(h.ellipse.sigma,2,2,size(Mean,2));
    Prior = h.ellipse.Pk;
    fun = 'pdfgmm';
    save(strcat(path,name),'Mean','Cov','Prior','fun');
  end

case 'loadmodel'
  % Reads a legacy model file with fields MI, SIGMA, K and optional Pk.
  h=get(hfigure,'UserData');
  sets=get(h.axes1,'UserData');
  if isempty(sets)==1,
    return;
  end
  [name,path]=uigetfile('*.mat','Load model');
  if name ~= 0,
    fname=strcat(path,name);
    demo_emgmm('reset',hfigure);
    model=load(fname);
    % isfield, not exist('model.Pk'): exist() does not look up struct
    % fields, so it always returned 0 and the stored Pk was discarded.
    if ~isfield(model,'Pk'),
      model.Pk=ones(1,sum(model.K))/sum(model.K);
    end
    h.ellipse.mi = model.MI;
    h.ellipse.sigma=model.SIGMA;
    h.ellipse.Pk = model.Pk;
    h.ellipse.solution=0;
    h.ellipse.t=0;
    set(h.edclasses,'String',num2str(sum(model.K)));
    h.ellipse.classes=sum(model.K);
    h.ellipse.features=get(h.pufeatures,'Value');
    set(hfigure,'UserData',h);
    % NOTE(review): with t reset to 0, 'step' calls emgmm without an
    % initial model, so the loaded parameters are not used as the EM
    % starting point - confirm whether that is intended.
    demo_emgmm('step',hfigure);
  end

case 'play'
  h=get(hfigure,'UserData');
  sets=get(h.axes1,'UserData');
  if isempty(sets)==1 | h.ellipse.solution==1,
    return;
  end

  % Lock the GUI while running; only Stop stays active.
  set([h.editer,h.btstep,h.btclose,h.btplay,...
       h.btreset,h.btinfo,h.btload,h.btcreat,h.txiter],...
     'Enable','off');
  set(h.btstop,'Enable','on');

  iter=str2num(get(h.editer,'String'));
  if h.ellipse.t==0,
    % First step of a new run: read and freeze the model settings.
    h.ellipse.classes=str2num(get(h.edclasses,'String'));
    h.ellipse.features=get(h.pufeatures,'Value');
    set([h.xbrandom,h.edclasses,h.txclasses,h.pufeatures,h.txfeatures],...
       'Enable','off');
  end
  randinit=get(h.xbrandom,'Value');

  % Repeat EM steps until EM converges or the Stop button sets its flag.
  set(h.btstop,'UserData',0);
  running=1;
  while running==1 & get(h.btstop,'UserData')==0,
    [h,converged]=em_update(h,sets,iter,randinit,AXIST_ADD,AXISY_ADD);
    if converged,
      running=0;
    end
    set(hfigure,'UserData',h);
    drawnow;
  end

  set([h.editer,h.btstep,h.btclose,h.btplay,...
       h.btreset,h.btinfo,h.btload,h.btcreat,h.txiter],...
     'Enable','on');
  set(h.btstop,'Enable','off');
  export_solution(h);

case 'step'
  h=get(hfigure,'UserData');
  sets=get(h.axes1,'UserData');
  if isempty(sets)==1 | h.ellipse.solution==1,
    return;
  end

  iter=str2num(get(h.editer,'String'));
  if h.ellipse.t==0,
    % First step of a new run: read and freeze the model settings.
    h.ellipse.classes=str2num(get(h.edclasses,'String'));
    h.ellipse.features=get(h.pufeatures,'Value');
    set([h.xbrandom,h.edclasses,h.txclasses,h.pufeatures,h.txfeatures],'Enable','off');
  end
  randinit=get(h.xbrandom,'Value');

  h=em_update(h,sets,iter,randinit,AXIST_ADD,AXISY_ADD);
  drawnow;
  set(hfigure,'UserData',h);
  export_solution(h);

case 'getfile'
  % Ask for a MAT-file, validate it and load it as the current data.
  h=get(hfigure,'UserData');
  [name,path]=uigetfile('*.mat','Open file');
  if name~=0,
    file.pathname=strcat(path,name);
    file.path=path;
    file.name=name;
    if check2ddata(file.pathname)==1,
      set(h.btload,'UserData',file);
      demo_emgmm('loadsets',hfigure);
    else
      errordlg('This file does not contain required data.','Bad file','modal');
    end
  end

case 'loadsets'
  % Load the file remembered in the 'Load data' button and reset the demo.
  h=get(hfigure,'UserData');
  file=get(h.btload,'UserData');
  sets=load(file.pathname);
  sets.K = size(sets.X,2);   % number of samples (columns of X)
  sets.N = size(sets.X,1);   % dimension (rows of X)
  set(h.axes1,'UserData',sets);
  demo_emgmm('reset',hfigure);
  drawnow;

case 'reset'
  h=get(hfigure,'UserData');
  file=get(h.btload,'UserData');
  sets=get(h.axes1,'UserData');

  % Forget the current estimate and the logL history.
  h.ellipse.mi=[];
  h.ellipse.sigma=[];
  h.ellipse.t=0;
  h.ellipse.Pk=[];
  h.center=-1;
  h.ellipse.handler=-1;
  h.ellipse.solution=0;
  h.graph1.time=[];
  h.graph1.axist=0;
  h.graph1.loglik=[];

  % Recreate the (initially hidden) logL curve.
  clrchild(h.axes2);
  axes(h.axes2);
  axis auto;
  h.graph1.handler=plot([0],[0],'b','Parent',h.axes2,...
     'EraseMode','background','Visible','off');
  set(hfigure,'UserData',h);

  % Unlock the model settings.
  set([h.xbrandom,h.edclasses,h.txclasses,h.pufeatures,h.txfeatures],'Enable','on');
  set(h.title2,'String','Log-likelihood, logL(t)');
  set(h.txsteps,'String','step number t=0. ');

  % Clear the data axes.
  set(get(h.axes1,'Children'),'EraseMode','normal');
  clrchild(h.axes1);
  drawnow;
  axes(h.axes1);

  if isempty(sets)==0,
    win=cmpwin(min(sets.X'),max(sets.X'),BORDER,BORDER);
    setaxis(h.axes1,win);
    ppatterns(sets.X);
    set(h.title1,'String',sprintf('File: %s, # of points K = %d',file.name,sum(sets.K)));
  else
    % No data: show a hint. builtin('text',...) is used because this
    % function historically had a local variable named text.
    set(h.title1,'String','No data loaded');
    pos=get(h.axes1,'Position');
    fsize=min(pos(3),pos(4))/7;
    setaxis(h.axes1,[-1 1 -1 1]);
    builtin('text',0,0,'Press ''Load data'' button.',...
       'Parent',h.axes1,...
       'HorizontalAlignment','center',...
       'FontUnits','normalized',...
       'Clipping','on',...
       'FontSize',fsize);
    builtin('text',0,-fsize*2,...
       'Load sample data from ../toolboxroot/data/gmm\_samples/ ',...
       'Parent',h.axes1,...
       'HorizontalAlignment','center',...
       'FontUnits','normalized',...
       'Clipping','on',...
       'FontSize',fsize*0.65);
  end
  drawnow;

case 'creatdata'
  % createdata calls demo_emgmm('created',hfigure,path,name) when done.
  createdata('finite',10,'demo_emgmm','created',hfigure);

case 'created'
  figure(hfigure);
  h=get(hfigure,'UserData');
  path=varargin{1};
  name=varargin{2};
  pathname=strcat(path,name);
  if check2ddata(pathname)==1,
    file.pathname=pathname;
    file.path=path;
    file.name=name;
    set(h.btload,'UserData',file);
    demo_emgmm('loadsets',hfigure);
  else
    errordlg('This file does not contain required data.','Bad file','modal');
  end

case 'info'
  helpwin(mfilename);
end

function [h,converged]=em_update(h,sets,iter,randinit,axist_add,axisy_add)
% EM_UPDATE Run ITER more EM iterations and refresh the demo window.
%  H         - demo state struct (the figure's 'UserData').
%  SETS      - loaded data; SETS.X holds the samples, one per column.
%  ITER      - number of additional EM iterations.
%  RANDINIT  - value of the 'Random init' checkbox (options.rand).
%  AXIST_ADD, AXISY_ADD - logL graph axis extension constants.
%  Returns the updated H and CONVERGED (true when emgmm's exitflag is 1).
%  When converged, only the step text is updated. Otherwise the logL
%  graph, its axis limits and the mixture ellipses are redrawn too.

options.rand = randinit;
% Popup value 1 = Diagonal, 2 = Full.
switch 3-h.ellipse.features,
  case 1,
    options.cov_type = 'full';
  case 2,
    options.cov_type = 'diag';
end
options.ncomp=h.ellipse.classes;
% tmax is offset by the current step t (passed as init_model.t) so that
% ITER further iterations are allowed - assumes emgmm counts from init_model.t.
options.tmax = iter+h.ellipse.t;

if h.ellipse.t >= 1,
  % Continue from the current estimate.
  init_model.logL = h.ellipse.logL;
  init_model.Alpha = h.ellipse.Alpha;
  init_model.t = h.ellipse.t;
  init_model.Mean = h.ellipse.mi;
  init_model.Cov = reshape(h.ellipse.sigma,2,2,size(h.ellipse.mi,2));
  init_model.Prior = h.ellipse.Pk;
  model=emgmm(sets.X,options,init_model);
else
  model=emgmm(sets.X,options);
end

% Store the new estimate; covariances are kept side by side as a 2 x 2K matrix.
h.ellipse.mi = model.Mean;
h.ellipse.sigma = reshape( model.Cov,2,2*size(model.Mean,2));
h.ellipse.Pk = model.Prior;
% eI(i) is the row of the largest Alpha entry in column i; pnmix uses it
% as the component label of sample i.
[tmp,eI]=max(model.Alpha);
h.ellipse.solution = model.exitflag;
h.ellipse.t = model.t;
h.ellipse.Alpha = model.Alpha;
h.ellipse.logL = model.logL(end);

converged = (h.ellipse.solution==1);
msg=sprintf('step number t=%d ',h.ellipse.t);
if converged,
  msg=[msg ',EM has converged.'];
  set(h.txsteps,'String',msg);
else
  set(h.txsteps,'String',msg);

  % Append the current log-likelihood to the graph.
  val= mln(sets.X,h.ellipse.mi,h.ellipse.sigma,h.ellipse.Pk);
  set(h.title2,'String',sprintf('Log-likelihood, logL(t) = %f',val));
  h.graph1.time=[h.graph1.time,h.ellipse.t];
  h.graph1.loglik=[h.graph1.loglik,val];

  % Grow the graph limits when the new point falls outside them.
  ylimit=get(h.axes2,'YLim');
  if ylimit(2) < val,
    set(h.axes2,'YLim',[ylimit(1) val+axisy_add]);
  end
  if h.ellipse.t > h.graph1.axist,
    h.graph1.axist=h.ellipse.t+iter*axist_add;
    set(h.axes2,'XLim',[1 h.graph1.axist]);
  end
  set(h.graph1.handler,'XData',h.graph1.time,'YData',h.graph1.loglik,...
     'Visible','on');

  % New ellipses are plotted into the current axes, so select the data axes.
  if h.ellipse.handler==-1,
    axes(h.axes1);
  end
  [h.ellipse.handler,h.center]=...
     pnmix(sets.X,h.ellipse.mi,h.ellipse.sigma,eI,h.ellipse.handler,h.center);
end

function export_solution(h)
% EXPORT_SOLUTION Copy the current estimate to the global variables
%  UNSU_MI, UNSU_SIGMA, UNSU_PK, UNSU_solution and UNSU_t.
global UNSU_MI
global UNSU_SIGMA
global UNSU_PK
global UNSU_solution
global UNSU_t
UNSU_MI = h.ellipse.mi;
UNSU_SIGMA = h.ellipse.sigma;
UNSU_PK = h.ellipse.Pk;
UNSU_solution = h.ellipse.solution;
UNSU_t = h.ellipse.t;
function []=clrchild(handle)
% CLRCHILD Delete all child objects of the graphics object HANDLE.
kids = get(handle,'Children');
delete(kids);
return;
function []=setaxis(handle,rect)
% SETAXIS Set the axis limits of HANDLE from RECT.
%  RECT = [xmin xmax ymin ymax] or, when it has at least six columns,
%  [xmin xmax ymin ymax zmin zmax].
set(handle,'XLim',rect(1:2),'YLim',rect(3:4));
if size(rect,2) < 6,
  return;
end
set(handle,'ZLim',rect(5:6));
return;
function [win]=cmpwin(mins,maxs,xborder,yborder)
% CMPWIN Axis window [x1 x2 y1 y2] enclosing a 2D data range plus a margin.
%  MINS, MAXS       - vectors of per-coordinate minima/maxima; only
%                     element 1 (x) and element 2 (y) are used.
%  XBORDER, YBORDER - margin as a fraction of the data range along x / y.
%                     A range smaller than 1 is treated as 1, so degenerate
%                     (e.g. single-point) data still get a visible margin.
dx=max( (maxs(1)-mins(1)), 1 )*xborder;
dy=max( (maxs(2)-mins(2)), 1 )*yborder;
% Both y limits use the y margin dy (the upper one previously used dx).
win=[mins(1)-dx, maxs(1)+dx, mins(2)-dy, maxs(2)+dy];
function [logL] = mln(X,MI,SIGMA,Pk)
% MLN Log-likelihood of the samples X under a Gaussian mixture.
%  X     - samples, one per column.
%  MI    - component means, one per column (K columns, dimension D rows).
%  SIGMA - covariances stacked side by side; block k is the D x D matrix
%          SIGMA(:,(k-1)*D+1:k*D).
%  Pk    - mixing weights, Pk(k) for component k.
%  logL  = sum over samples n of log( sum over k of Pk(k)*normald(x_n,...) ).
dim = size(MI,1);
ncomp = size(MI,2);
nsamples = size(X,2);

% weighted(n,k) = Pk(k) * p(x_n | component k)
weighted = zeros(nsamples,ncomp);
for k = 1:ncomp,
  cols = (k-1)*dim+1 : k*dim;
  dens = normald(X,MI(:,k),SIGMA(:,cols));
  weighted(:,k) = Pk(k)*dens(:);
end
logL = sum(log(sum(weighted,2)));
function [p]=normald(X,mi,sigma)
% NORMALD Gaussian density N(mi,sigma) evaluated at the columns of X.
%  Computes exp(-0.5*mahalan(X,mi,sigma)) / ((2*pi)^(d/2)*sqrt(det(sigma)))
%  with d = size(X,1). Assumes mahalan returns the squared Mahalanobis
%  distance of each column of X - TODO confirm against mahalan.
d = size(X,1);
normconst = (2*pi)^(d/2) * sqrt(det(sigma));
p = exp(-0.5*mahalan(X,mi,sigma)) / normconst;
function [hellipse,hcenter]=pnmix(X,MI,SIGMA,I,hellipse,hcenter)
% PNMIX Draw or update the ellipses and center marks of a Gaussian mixture.
%  X        - samples, one per column.
%  MI       - component means, one per column (K columns).
%  SIGMA    - covariances stacked side by side; block i is
%             SIGMA(:,(i-1)*DIM+1:DIM*i) with DIM = size(X,1).
%  I        - I(n) is the component index assigned to sample n.
%  HELLIPSE - existing ellipse handles, or -1 (also the default when
%             omitted) to create new plot objects in the current axes.
%  HCENTER  - existing '+' center-marker handles (read only when
%             HELLIPSE is not -1).
%  The size passed to ellips for component i is the largest
%  sqrt(mahalan(...)) value over the samples assigned to i (0 if none).
%  Returns the ellipse and center handles, which may be newly created.

if nargin < 5,
  hellipse=-1;
end

dim=size(X,1);
nsamples=size(X,2);
ncomp=size(MI,2);

% Per-component maximum distance of the assigned samples from the mean.
maxr=zeros(1,ncomp);
for n=1:nsamples,
  k=I(n);
  r=sqrt(mahalan(X(:,n),MI(:,k),SIGMA(:,(k-1)*dim+1:dim*k)));
  if maxr(k) < r,
    maxr(k)=r;
  end
end

% Decide once, before any handle is overwritten, whether to create new
% graphics objects or to move the existing ones.
create=0;
if hellipse==-1,
  create=1;
end

for k=1:ncomp,
  cols=(k-1)*dim+1:dim*k;
  [ex,ey]=ellips(MI(:,k),inv(SIGMA(:,cols)),maxr(k),30);
  if create,
    hellipse(k)=plot(ex,ey,'k','EraseMode','xor');
    hcenter(k)=plot(MI(1,k),MI(2,k),'+k','EraseMode','xor');
  else
    set(hellipse(k),'XData',ex,'YData',ey,'Visible','on');
    set(hcenter(k),'XData',MI(1,k),'YData',MI(2,k),'Visible','on');
  end
  drawnow;
end