function [best_model,Errors] = evalsvm(arg1,arg2,arg3)
% EVALSVM Trains and evaluates Support Vector Machines classifier.
%
% Synopsis:
%  [model,Errors] = evalsvm(data,options)
%  [model,Errors] = evalsvm(trn_data,val_data,options)
%
% Description:
%  [model,Errors] = evalsvm(data,options) uses cross-validation
%   to assess SVM classifiers with given kernel arguments and
%   regularization constants.
%
%   The kernel type is given in options.ker (default 'linear';
%   see 'help kernel'). The SVM solver to be used is specified by
%   the field options.solver (default 'smo'). Both binary and
%   multi-class SVM solvers are allowed; choose a solver that is
%   appropriate for the number of labels in the training data.
%   The input data have the following format:
%    data.X [dim x num_data] ... training vectors.
%    data.y [1 x num_data] ... labels.
%
%   The set of SVM parameters to be evaluated is specified in:
%    options.arg [dimarg x nargs] ... enumeration of kernel arguments;
%      dimarg determines the number of kernel arguments (e.g.,
%      dimarg = 1 for 'rbf' kernel and dimarg = 2 for 'sigmoid').
%    options.C [1 x nc] ... enumeration of regularization constants.
%   Both fields are required and must be non-empty.
%
%   Some extra parameters for the selected SVM solver can be
%   specified in the field options.solver_options.
%
%   Each configuration of SVM parameters is evaluated using
%   cross-validation. The number of folds is given in
%   options.num_folds (default 5). The computed cross-validation
%   errors with respect to the SVM parameters are returned in
%   Errors [nc x nargs]. The returned model belongs to the
%   configuration with the smallest cross-validation error; note
%   that it is the model trained in the last fold of that
%   configuration, it is not retrained on the whole data set.
%
%   The progress info is displayed if options.verb is set to 1
%   (default 0).
%
%  [model,Errors] = evalsvm(trn_data,val_data,options) each
%   SVM is trained on the trn_data and evaluated on the
%   validation val_data instead of using cross-validation.
%
% Example:
%  trn = load('riply_trn');
%  tst = load('riply_tst');
%  options.ker = 'rbf';
%  options.arg = [0.1 0.5 1 5];
%  options.C = [1 10 100];
%  options.solver = 'smo';
%  options.num_folds = 5;
%  options.verb = 1;
%  [model,Errors] = evalsvm(trn,options);
%  figure; mesh(options.arg,options.C,Errors);
%  hold on; xlabel('arg'); ylabel('C');
%  ypred = svmclass(tst.X,model);
%  cerror(ypred,tst.y)
%
% See also:
%  SMO, SVMLIGHT, SVMCLASS.
%
% Modifications:
% 17-sep-2004, VF, Help improved. Info about training stage added.
% 18-aug-2004, VF, svm_options changed to solver_options
% 4-june-2004, VF
% 3-jun-2004, VF

% --- Input arguments -------------------------------------------------
% Reject unsupported calling forms first; otherwise the outputs would
% silently stay unassigned.
if nargin == 2
  use_crossval = 1;
  options = c2s(arg2);
elseif nargin == 3
  use_crossval = 0;
  options = c2s(arg3);
else
  error('evalsvm: two or three input arguments are required; see ''help evalsvm''.');
end

if ~isfield(options,'verb'), options.verb = 0; end
if ~isfield(options,'solver'), options.solver = 'smo'; end
if ~isfield(options,'ker'), options.ker = 'linear'; end
if ~isfield(options,'solver_options'), options.solver_options = []; end
if use_crossval && ~isfield(options,'num_folds'), options.num_folds = 5; end

% The parameter grid has no sensible default, so demand it explicitly.
if ~isfield(options,'arg') || isempty(options.arg)
  error('evalsvm: options.arg must enumerate at least one kernel argument.');
end
if ~isfield(options,'C') || isempty(options.C)
  error('evalsvm: options.C must enumerate at least one regularization constant.');
end

% --- Grid-search setup -----------------------------------------------
nargs = size(options.arg,2);
nc = length(options.C);
num_model = nc*nargs;

solver_options = options.solver_options;
solver_options.ker = options.ker;

Errors = zeros(nc,nargs);   % rows index C, columns index kernel arguments
min_error = inf;
have_best = 0;              % set once some configuration has been selected
best_C = [];
best_arg = [];
cnt_model = 0;

if use_crossval
  % The same partitioning is reused for every configuration.
  [itrn,itst] = crossval(size(arg1.X,2),options.num_folds);
  err_label = 'cross-validation error';
  final_label = 'cross-validation error';
else
  err_label = ' err';
  final_label = 'testing error';
end

% --- Evaluate every (arg,C) configuration ----------------------------
for i = 1:nargs
  arg = options.arg(:,i);

  for j = 1:nc
    cnt_model = cnt_model + 1;
    C = options.C(j);
    solver_options.C = C;
    solver_options.arg = arg;

    if options.verb == 1
      fprintf('Model %d/%d: ker=%s, C=%f, arg=', ...
          cnt_model, num_model, solver_options.ker, C);
      fprintf('%f ', arg);
      fprintf('\n');
    end

    if use_crossval
      fold_error = zeros(options.num_folds,1);

      for k = 1:options.num_folds
        if options.verb == 1
          fprintf('fold %d/%d: #trn/tst = %d/%d, training', ...
              k, options.num_folds, length(itrn{k}), length(itst{k}));
        end

        trn.X = arg1.X(:,itrn{k});
        trn.y = arg1.y(:,itrn{k});
        tst.X = arg1.X(:,itst{k});
        tst.y = arg1.y(:,itst{k});

        model = feval( options.solver, trn, solver_options );

        if options.verb == 1
          % Training error is computed for the progress report only.
          ypred = feval( model.fun, trn.X, model );
          fprintf(' err = %.4f, testing', cerror( ypred, trn.y ));
        end

        ypred = feval( model.fun, tst.X, model );
        fold_error(k) = cerror( ypred, tst.y );

        if options.verb == 1
          fprintf(' err = %.4f\n', fold_error(k));
        end
      end

      err = mean( fold_error );
    else
      if options.verb == 1
        fprintf('#trn/tst = %d/%d, training', ...
            length(arg1.y), length(arg2.y));
      end

      model = feval( options.solver, arg1, solver_options );

      if options.verb == 1
        % Training error is computed for the progress report only.
        ypred = feval( model.fun, arg1.X, model );
        fprintf(' err = %.4f, testing', cerror( ypred, arg1.y ));
      end

      ypred = feval( model.fun, arg2.X, model );
      err = cerror( ypred, arg2.y );
    end

    Errors(j,i) = err;

    if min_error > err
      % Strict comparison: on ties the earliest configuration is kept.
      % In cross-validation mode 'model' is the one from the last fold.
      min_error = err;
      best_model = model;
      best_C = C;
      best_arg = arg;
      have_best = 1;
      if options.verb == 1
        fprintf('%s = %.4f (best so far)\n\n', err_label, err);
      end
    else
      if options.verb == 1
        fprintf('%s = %.4f\n\n', err_label, err);
      end
    end
  end
end

% No configuration produced a comparable error (e.g. all errors are NaN).
if ~have_best
  error('evalsvm: no model could be selected; all evaluated errors are NaN.');
end

if options.verb == 1
  % Report the locally tracked parameters so the summary does not depend
  % on the solver storing its options inside the returned model.
  fprintf('best model: ker=%s, C=%f, arg=', solver_options.ker, best_C);
  fprintf('%f ', best_arg);
  fprintf('\n%s = %.4f\n', final_label, min_error);
end