cosmo match skl

function msk = cosmo_match(haystack, needle, varargin)
    % returns a mask indicating which elements in haystack occur in needle
    %
    % msk=cosmo_match(haystack1, needle1[, haystack2, needle2, ...])
    %
    % Inputs:
    %   haystack*         numeric vector, or cell with strings.
    %   needle*           numeric vector, or cell with strings. A string is
    %                     also allowed and interpreted as {needle}; this
    %                     holds for every needle argument, not only the
    %                     first one.
    %                     A function handle is also allowed; it is applied
    %                     to haystack as a whole (numeric haystack) or to
    %                     each element of haystack (cell haystack) and must
    %                     return a boolean array.
    %
    % Output:
    %   msk               boolean array of the same size as haystack1, with
    %                     true where the value in haystack is equal to at
    %                     least one value in needle. If multiple
    %                     haystack/needle pairs are provided, then the
    %                     haystack inputs should have the same number of
    %                     elements, and msk contains the intersection of
    %                     the individual masks.
    %
    % Examples
    %     % simple character comparison
    %     cosmo_match({'a','b','c'},{'b','c','d','e','b'})
    %     %|| [false true true]
    %     % swap the order of inputs
    %     cosmo_match({'b','c','d','e','b'},{'a','b','c'})
    %     %|| [true true false false true]
    %
    %     % in an fMRI dataset, get mask for features with first spatial
    %     % dimension indices 5 or 7.
    %     msk=cosmo_match(ds.fa.i,[5 7]);
    %
    %     % get mask for chunk values 1 and 4
    %     msk=cosmo_match(ds.sa.chunks,[1 4]);
    %
    %     % get mask for chunk values 1 and 4, and target values 1, 3 or 6.
    %     msk=cosmo_match(ds.sa.chunks, [1 4], ds.sa.targets, [1 3 6]);
    %
    %     % get feature mask for the fourth channel in an MEEG dataset
    %     msk=cosmo_match(ds.fa.chan,4)
    %
    % Notes:
    %   - the output of this function can be used with cosmo_slice
    %     to select features or samples of interest
    %   - this function does not interpret a string haystack as the name
    %     of a feature dimension; to select feature dimension values in an
    %     fmri or meeg dataset (e.g., channel selection), see
    %     cosmo_dim_match
    %   - when haystacks of different pairs differ in orientation (row
    %     versus column vector) but have the same number of elements, the
    %     output takes the shape of the first haystack
    %
    % See also: cosmo_slice, cosmo_stack, cosmo_dim_match
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    if nargin < 2
        error('Need at least two input arguments');
    elseif mod(nargin, 2) ~= 0
        error('Need an even number of input arguments');
    end

    if ischar(needle)
        needle = {needle};
    end

    msk = match(needle, haystack);

    % remaining arguments come in (haystack, needle) pairs
    for argpos = 1:2:(nargin - 2)
        other_haystack = varargin{argpos};
        other_needle = varargin{argpos + 1};

        % treat a string needle the same way as for the first pair
        if ischar(other_needle)
            other_needle = {other_needle};
        end

        other_msk = match(other_needle, other_haystack);

        % check the size is the same
        if numel(msk) ~= numel(other_msk)
            error('Illegal conjunction: masks have %d ~= %d elements', ...
                  numel(msk), numel(other_msk));
        end

        % conjunction; reshape so that a row mask combined with a column
        % mask gives an element-wise result rather than an error or an
        % implicitly expanded matrix
        msk = msk & reshape(other_msk, size(msk));
    end

function msk = match(needle, haystack)
    % dispatcher that selects one of match_function_handle,
    % match_numeric_scalar or match_vectors based on the input types
    %
    % Inputs:
    %   needle      values to look for, or a function handle
    %   haystack    values to be searched
    %
    % Output:
    %   msk         result of the selected matching helper

    % a function handle is never a cell, string or numeric value, so the
    % class test alone identifies this case
    if isa(needle, 'function_handle')
        msk = match_function_handle(needle, haystack);
        return
    end

    % fast path for the most common case: numeric vector versus scalar
    is_scalar_lookup = isnumeric(needle) && isnumeric(haystack) && ...
                                            numel(needle) == 1;
    if is_scalar_lookup
        msk = match_numeric_scalar(needle, haystack);
    else
        msk = match_vectors(needle, haystack);
    end

function msk = match_numeric_scalar(needle, haystack)
    % element-wise equality of a numeric haystack with a single numeric
    % needle value
    %
    % haystack must be a vector or empty (enforced by check_vec_or_empty,
    % which raises an error otherwise); the output is the result of
    % comparing each haystack element with needle
    check_vec_or_empty(haystack);
    msk = eq(haystack, needle);

function msk = match_function_handle(func, data)
    % computes a mask by applying a function handle to the haystack
    %
    % Inputs:
    %   func    function handle. For cell data it is called once per
    %           element and must return a boolean scalar; otherwise it is
    %           called once with data as a whole and must return a
    %           boolean array
    %   data    vector or empty array (cell or non-cell)
    %
    % Output:
    %   msk     boolean output of func; for empty cell data, an empty
    %           boolean array with the same size as data
    %
    % An error is raised if data is neither a vector nor empty, or if the
    % output of func is not boolean
    check_vec_or_empty(data);
    if iscell(data)
        if isempty(data)
            % cellfun on an empty cell returns an empty double rather than
            % an empty logical, which would trigger the boolean check
            % below even though nothing was evaluated
            msk = false(size(data));
        else
            msk = cellfun(func, data, 'UniformOutput', true);
        end
    else
        msk = func(data);
    end

    if ~islogical(msk)
        error('function %s should return boolean array', func2str(func));
    end

function msk = match_vectors(needle, haystack)
    % matches every element in haystack against every element in needle
    %
    % Inputs:
    %   needle      numeric vector or cell with strings (or empty)
    %   haystack    numeric vector or cell with strings (or empty); must
    %               be of the same kind (numeric or cell) as needle
    %
    % Output:
    %   msk         boolean array with the same size as haystack, true
    %               where the haystack element equals at least one
    %               element in needle
    %
    % An error is raised if either input is neither a vector nor empty, if
    % a cell input contains a non-string element, or if the inputs are
    % not both numeric or both cells
    %
    % Note: an intermediate numel(needle) x numel(haystack) boolean matrix
    % is built, with one row per needle element
    nrows = check_vec_or_empty(needle);
    ncols = check_vec_or_empty(haystack);

    if isnumeric(needle) && isnumeric(haystack)
        % use the non-conjugating transpose (.'), so that complex haystack
        % values are compared as they are instead of as their conjugates
        matches = bsxfun(@eq, needle(:), haystack(:).');
    elseif iscell(needle) && iscell(haystack)
        % all(...) of an empty array is true, so empty cells pass
        if ~all(cellfun(@ischar, haystack)) || ...
                    ~all(cellfun(@ischar, needle))
            error('cell must contain strings');
        end

        matches = false(nrows, ncols);
        for k = 1:nrows
            % strcmp compares a single string with every haystack element
            matches(k, strcmp(needle{k}, haystack)) = true;
        end
    else
        error(['Illegal inputs %s and %s: need numeric arrays or '...
               'cell with strings'], class(needle), class(haystack));
    end

    if isempty(matches)
        % either input is empty, so nothing can match
        msk = false(size(haystack));
    else
        % a haystack element matches if any needle row matches it
        msk = reshape(any(matches, 1), size(haystack));
    end

function n = check_vec_or_empty(x)
    % verifies that the input is a vector or empty, and counts its elements
    %
    % Input:
    %   x       any array
    %
    % Output:
    %   n       number of elements in x
    %
    % An error is raised if x is non-empty and not a vector
    n = numel(x);

    is_acceptable = n == 0 || isvector(x);
    if ~is_acceptable
        error('Input argument is not a vector');
    end