cosmo classify matlabsvm 2class

function predicted = cosmo_classify_matlabsvm_2class(samples_train, targets_train, samples_test, opt)
    % svm classifier wrapper (around svmtrain/svmclassify)
    %
    % predicted=cosmo_classify_matlabsvm_2class(samples_train, targets_train, samples_test, opt)
    %
    % Inputs:
    %   samples_train      PxR training data for P samples and R features
    %   targets_train      Px1 training data classes
    %   samples_test       QxR test data
    %   opt                struct with options. supports any option that
    %                      svmtrain supports
    %
    % Output:
    %   predicted          Qx1 predicted data classes for samples_test
    %
    % Notes:
    %  - this function uses Matlab's builtin svmtrain function, which has
    %    the same name as LIBSVM's version. Use of this function is not
    %    supported when LIBSVM's svmtrain precedes in the matlab path; in
    %    that case, adjust the path or use cosmo_classify_libsvm instead.
    %  - Matlab's SVM classifier is rather slow, especially for multi-class
    %    data (more than two classes). When classification takes a long time,
    %    consider using libsvm.
    %  - for a guide on svm classification, see
    %      http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
    %    Note that cosmo_crossvalidate and cosmo_crossvalidation_measure
    %    provide an option 'normalization' to perform data scaling
    %  - As of Matlab 2017a (maybe earlier), Matlab gives the warning that
    %      'svmtrain will be removed in a future release. Use fitcsvm instead.'
    %    however fitcsvm gives different results than svmtrain; as a result
    %    cosmo_classify_matlabcsvm gives different results than
    %    cosmo_classify_matlabsvm. In this function the warning message is
    % .  suppressed.
    %  - As of Matlab 2018a, this function cannot be used anymore. Use
    %    cosmo_classify_matlabcsvm instead.
    %
    % See also svmtrain, svmclassify, cosmo_classify_matlabsvm,
    %          cosmo_classify_matlabcsvm
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    if nargin < 4
        opt = struct();
    end

    [ntrain, nfeatures] = size(samples_train);
    [ntest, nfeatures_] = size(samples_test);
    ntrain_ = numel(targets_train);

    if nfeatures ~= nfeatures_ || ntrain_ ~= ntrain
        error('illegal input size');
    end

    if nfeatures == 0
        % matlab's svm cannot deal with empty data, so predict all
        % test samples as the class of the first sample
        predicted = targets_train(1) * ones(ntest, 1);
        return
    end

    classes = unique(targets_train);
    nclasses = numel(classes);
    if nclasses ~= 2
        error(['%s requires 2 classes, found %d. Consider using '...
               'cosmo_classify_{matlab,lib}svm instead'], ...
              mfilename(), nclasses);
    end

    opt_cell = opt2cell(opt);

    % train & test; if it fails, see if this caused by non-functioning
    % matlabsvm

    try
        if ~cached_has_matlabsvm()
            cosmo_check_external('matlabsvm');
        end

        % Use svmtrain and svmclassify to get predictions for the testing set.

        % disable warnings shown by Matlab 2017 and later
        orig_warning_state = warning();
        cleaner = onCleanup(@()warning(orig_warning_state));
        warning('off', 'stats:obsolete:ReplaceThisWith');
        warning('off', ['stats:obsolete:ReplaceThisWith'...
                        'MethodOfObjectReturnedBy']);

        % store most recent warning
        [orig_lastmsg, orig_lastid] = lastwarn();

        model = svmtrain(samples_train, targets_train, opt_cell{:});
        predicted = svmclassify(model, samples_test);

        % deal with possible warning shown (Matlab >= 2016)
        [warning_msg, warning_id] = lastwarn();
        if strcmp(warning_id, 'stats:obsolete:ReplaceThisWith')
            % only show warning once (by default) if this is a
            % a stats:obsolete message
            suffix = ['CoSMoMVPA note: the more recent '...
                      'fitcsvm / svmsmoset classifiers produce '...
                      'different results '...
                      'than the older svmtrain function. '...
                      'To use fitcsvm, use cosmo_classify_matlabcsvm'];
            cosmo_warning('%s\n%s', warning_msg, suffix);
        elseif ~strcmp(warning_id, orig_lastid)
            % new warning was issued , different from stats:obsolete one;
            % show warning message
            cosmo_warning(warning_id, warning_msg);
        end

    catch
        caught_exception = lasterror();
        cosmo_check_external('matlabsvm');

        if strcmp(caught_exception.identifier, ...
                  'stats:svmtrain:NoConvergence')
            error(['SVM training did not converge. Your options are:\n'...
                   ' 1) increase ''boxconstraint''\n'...
                   ' 2) increase ''tolkkt''\n'...
                   ' 3) set ''kktviolationlevel'' to a positive value\n'...
                   ' 4) use a different classifier\n'...
                   'If you do not have a strong preference for '...
                   'either option, you are advised to try option (4) '...
                   'using cosmo_classify_lda'], '');
        else
            rethrow(caught_exception);
        end
    end

    % helper function to convert cell to struct
function opt_cell = opt2cell(opt)

    if isempty(opt)
        opt_cell = cell(0);
        return
    end

    to_keep = {'kernel_function', ...
               'rbf_sigma', ...
               'polyorder', ...
               'mlp_params', ...
               'method', ...
               'options', ...
               'tolkkt', ...
               'kktviolationlevel', ...
               'kernelcachelimit', ...
               'boxconstraint', ...
               'autoscale', ...
               'showplot'};

    fns = fieldnames(opt);
    keep_msk = cosmo_match(fns, to_keep);
    keep_fns = fns(keep_msk);
    keep_id = find(keep_msk);

    n = numel(keep_fns);
    opt_cell = cell(1, 2 * n);
    for k = 1:n
        fn = fns{keep_id(k)};
        opt_cell{k * 2 - 1} = fn;
        opt_cell{k * 2} = opt.(fn);
    end

function tf = cached_has_matlabsvm()
    persistent cached_tf

    if isequal(cached_tf, true)
        tf = true;
        return
    end

    cached_tf = cosmo_check_external('matlabsvm');
    tf = cached_tf;