cosmo average samples

function ds_avg = cosmo_average_samples(ds, varargin)
    % average subsets of samples by unique combinations of sample attributes
    %
    % ds_avg=cosmo_average_samples(ds, ...)
    %
    % Inputs:
    %   ds              dataset struct with field:
    %     .samples      NS x NF
    %     .sa           with fields .targets and .chunks
    %   'ratio', ratio  ratio (between 0 and 1) of samples to select for
    %                   each average. Not compatible with 'count' (default: 1).
    %   'count', c      number of samples to select for each average.
    %                   Not compatible with 'ratio'.
    %   'resamplings',s Maximum number of times each sample in ds is used for
    %                   averaging. Not compatible with 'repeats' (default: 1)
    %   'repeats', r    Number of times an average is computed for each unique
    %                   combination of targets and chunks. Not compatible with
    %                   'resamplings'
    %   'seed', d       Use seed d for pseudo-random sampling (optional); d
    %                   can be any integer between 1 and 2^32-1.
    %                   If this option is omitted, then different calls to this
    %                   function may (usually: will) return different results.
    %   'split_by',fs   A cell with fieldnames by which the dataset is split
    %                   prior to averaging each bin.
    %                   Default: {'targets','chunks'}.
    %
    %
    % Returns
    %   ds_avg          dataset struct with field:
    %      .samples     ('repeats'*unq) x NF, where
    %                   unq is the number of unique combinations of values in
    %                   sample attribute as indicated by 'split_by' (by
    %                   default, data is split by 'targets' and 'chunks').
    %                   Each sample is an average from samples that share the
    %                   same values for these attributes. The number of times
    %                   each sample is used to compute average values differs
    %                   by one at most.
    %      .sa          Sample attributes of the first sample that was
    %                   selected for each average.
    %      .fa,.a       Same as in ds (if present).
    %
    % Examples:
    %     % generate simple dataset with 3 times (2 targets x 3 chunks)
    %     ds=cosmo_synthetic_dataset('nreps',3);
    %     size(ds.samples)
    %     %|| [ 18 6 ]
    %     cosmo_disp([ds.sa.targets ds.sa.chunks])
    %     %|| [ 1         1
    %     %||   2         1
    %     %||   1         2
    %     %||   :         :
    %     %||   2         2
    %     %||   1         3
    %     %||   2         3 ]@18x2
    %     % average each unique combination of chunks and targets
    %     ds_avg=cosmo_average_samples(ds);
    %     cosmo_disp([ds_avg.sa.targets ds_avg.sa.chunks]);
    %     %|| [ 1         1
    %     %||   1         2
    %     %||   1         3
    %     %||   2         1
    %     %||   2         2
    %     %||   2         3 ]
    %     %
    %     % for each unique target-chunk combination, select 50% of the samples
    %     % randomly and average these; repeat the random selection process 4
    %     % times. Each sample in 'ds' is used twice (=.5*4) as an element
    %     % to compute an average. The output has 24 samples
    %     ds_avg2=cosmo_average_samples(ds,'ratio',.5,'repeats',4);
    %     cosmo_disp([ds_avg2.sa.targets ds_avg2.sa.chunks],'edgeitems',5);
    %     %|| [ 1         1
    %     %||   1         1
    %     %||   1         1
    %     %||   1         1
    %     %||   1         2
    %     %||   :         :
    %     %||   2         2
    %     %||   2         3
    %     %||   2         3
    %     %||   2         3
    %     %||   2         3 ]@24x2
    %
    % Notes:
    %  - this function averages feature-wise; the output has the same features
    %    as the input.
    %  - it can be used to average data from trials safely without circular
    %    analysis issues.
    %  - as a result the number of trials in each chunk and target is
    %    identical, so balancing of partitions is not necessary for data from
    %    this function.
    %  - the default behaviour of this function computes a single average for
    %    each unique combination of chunks and targets.
    %  - if the number of samples differs for different combinations of chunks
    %    and targets, then some samples may not be used to compute averages,
    %    as the least number of samples across combinations is used to set
    %    both the number of samples selected for each average and the number
    %    of averages that are computed.
    %  - As illustration, consider a dataset with the following number of
    %    samples for each unique targets and chunks combination
    %
    %    .sa.chunks     .sa.targets         number of samples
    %    ----------     -----------         -----------------
    %       1               1                   12
    %       1               2                   16
    %       2               1                   15
    %       2               2                   24
    %
    %    The least number of samples is 12, which determines how many averages
    %    are computed. Different parameters result in a different number of
    %    averages; some examples:
    %
    %       parameters                      number of output samples for each
    %                                       combination of targets and chunks
    %       ----------                      ---------------------------------
    %       'count', 2                      6 averages from 2 samples [*]
    %       'count', 3                      4 averages from 3 samples [*]
    %       'ratio', .25                    4 averages from 3 samples [*]
    %       'ratio', .5                     2 averages from 6 samples [*]
    %       'ratio', .5, 'repeats', 3       3 averages from 6 samples
    %       'ratio', .5, 'resamplings', 3   6 averages from 6 samples
    %
    %    [*]: not all samples in the input are used to compute the averages
    %         in the output.
    %
    %    Briefly, 'ratio' or 'count' determine, together with the least number
    %    of samples, how many samples are averaged for each output sample.
    %    'resamplings' and 'repeats' determine how many averages are taken,
    %    based on how many samples are averaged for each output sample.
    % -  To compute averages based on other sample attributes than 'targets'
    %    and 'chunks', use the 'split_by' option
    %
    % See also: cosmo_balance_partitions
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % merge user-supplied options with the defaults
    defaults = struct('seed', [], 'split_by', {{'targets', 'chunks'}});
    opt = cosmo_structjoin(defaults, varargin);

    % row indices in ds for each bin (unique combination of attributes)
    bin_rows = get_split_indices(ds, opt);
    nbins = numel(bin_rows);
    bin_sizes = cellfun(@numel, bin_rows);

    % for each bin, the within-bin positions to average (one column per
    % repeat), and the number of averages computed for every bin
    [bin_picks, nrepeat] = get_split_sample_ids(bin_sizes, opt);

    noutput = nrepeat * nbins;
    averages = zeros(noutput, size(ds.samples, 2));
    first_rows = zeros(noutput, 1);

    for b = 1:nbins
        rows_in_bin = bin_rows{b};
        picks = bin_picks{b};

        for r = 1:nrepeat
            % output rows are grouped by bin, with repeats adjacent
            out_row = (b - 1) * nrepeat + r;

            % translate within-bin positions to row indices in ds
            chosen = rows_in_bin(picks(:, r));

            averages(out_row, :) = mean(ds.samples(chosen, :), 1);

            % the first selected sample provides the attributes of the average
            first_rows(out_row) = chosen(1);
        end
    end

    % slice to get matching sample attributes, then overwrite the data
    ds_avg = cosmo_slice(ds, first_rows, 1, false);
    ds_avg.samples = averages;


function split_idxs = get_split_indices(ds, opt)
    % Return, for each bin, the row indices of the samples in that bin
    %
    % Inputs:
    %   ds          dataset struct with fields .samples and .sa
    %   opt         struct with field .split_by, a cell with names of
    %               fields in ds.sa that define the bins
    %
    % Output:
    %   split_idxs  cell with one element per bin. If .split_by is empty
    %               all samples form a single bin; otherwise the bins are
    %               those returned by cosmo_index_unique for the selected
    %               sample attributes.
    %
    % The result of the most recent call is cached. The cache key consists
    % of everything the result depends on: the sample attributes, the
    % 'split_by' option and the number of samples (the latter matters when
    % 'split_by' is empty, as .sa may then be identical for datasets with
    % different numbers of samples).
    persistent cached_sa
    persistent cached_split_by
    persistent cached_nsamples
    persistent cached_split_idxs

    if ~(isstruct(ds) && ...
         isfield(ds, 'samples') && ...
         isfield(ds, 'sa'))
        error(['First argument must be a dataset struct with fields '...
               'samples and sa']);
    end

    split_by = opt.split_by;
    nsamples = size(ds.samples, 1);

    cache_is_valid = isequal(cached_split_by, split_by) && ...
                     isequal(cached_nsamples, nsamples) && ...
                     isequal(cached_sa, ds.sa);

    if ~cache_is_valid
        if ~iscellstr(split_by)
            error('''split_by''  must be a cell with strings');
        end

        n_dim = numel(split_by);
        if n_dim == 0
            % no attributes to split by: all samples are in one bin
            new_split_idxs = {(1:nsamples)'};
        else
            values = cell(n_dim, 1);
            for k = 1:n_dim
                key = split_by{k};
                if ~isfield(ds.sa, key)
                    error('Missing field ''%s'' in .sa', key);
                end
                values{k} = ds.sa.(key);
            end
            new_split_idxs = cosmo_index_unique(values);
        end

        % update the cache only after the computation succeeded, so that
        % an error above leaves the previous cache contents consistent
        cached_split_idxs = new_split_idxs;
        cached_sa = ds.sa;
        cached_split_by = split_by;
        cached_nsamples = nsamples;
    end

    split_idxs = cached_split_idxs;


function [idx, value] = get_mutually_exclusive_param(opt, names, ...
                                                     default_idx, default_value)
    % Find which one of several mutually exclusive options was set
    %
    % Inputs:
    %   opt             struct with options
    %   names           cell with names of options of which at most one
    %                   may have a non-empty value in opt
    %   default_idx     index returned if none of the options is set
    %   default_value   value returned if none of the options is set
    %
    % Outputs:
    %   idx             position in names of the option with a non-empty
    %                   value, or default_idx if there is no such option
    %   value           the value of that option, or default_value
    %
    % An error is raised if more than one option has a non-empty value.
    idx = [];
    value = [];

    for k = 1:numel(names)
        key = names{k};
        if ~isfield(opt, key)
            continue
        end

        candidate = opt.(key);
        if isempty(candidate)
            % option present but not set; it must not overwrite the value
            % of an option that was found earlier
            continue
        end

        if ~isempty(idx)
            error(['The options ''%s'' and ''%s'' are mutually '...
                   'exclusive '], key, names{idx});
        end

        idx = k;
        value = candidate;
    end

    if isempty(idx)
        idx = default_idx;
        value = default_value;
    end


function [nselect, nrepeat] = get_selection_params(bin_counts, opt)
    % Determine how many samples go in each average and how many averages
    % are computed for each bin
    %
    % Inputs:
    %   bin_counts  number of samples in each bin
    %   opt         struct with options; 'ratio' or 'count' set the number
    %               of samples per average, 'resamplings' or 'repeats' set
    %               the number of averages. Each pair is mutually
    %               exclusive; 'ratio' and 'resamplings' default to 1.
    %
    % Outputs:
    %   nselect     number of samples used for each average
    %   nrepeat     number of averages computed for each bin
    smallest_bin = min(bin_counts);

    select_names = {'ratio', 'count'};
    [select_idx, select_value] = get_mutually_exclusive_param(opt, ...
                                                        select_names, 1, 1);

    if select_idx == 1
        % 'ratio': fraction of the smallest bin
        nselect = round(select_value * smallest_bin);
    else
        % 'count': used as is
        nselect = select_value;
    end

    ensure_in_range('Number of elements to select', nselect, ...
                    1, smallest_bin);

    repeat_names = {'resamplings', 'repeats'};
    [repeat_idx, repeat_value] = get_mutually_exclusive_param(opt, ...
                                                        repeat_names, 1, 1);

    if repeat_idx == 1
        % 'resamplings': as many averages as the smallest bin allows when
        % each sample is used at most this many times
        nrepeat = floor(repeat_value * min(bin_counts ./ nselect));
    else
        % 'repeats': used as is
        nrepeat = repeat_value;
    end

    ensure_in_range('Number of repeats', ...
                    nrepeat, 1, Inf);


function ensure_in_range(label, val, min_val, max_val)
    % Raise an error unless val is an integer-valued numeric scalar in the
    % range min_val..max_val (inclusive)
    %
    % Inputs:
    %   label       description of the value, used as prefix of the error
    %               message
    %   val         value to check
    %   min_val     smallest allowed value
    %   max_val     largest allowed value (can be Inf)
    %
    % The checks are applied in order and the first one that fails
    % determines the error message.
    if ~isscalar(val) || ~isnumeric(val)
        postfix = 'must be numeric scalar';
    elseif round(val) ~= val
        % this also rejects NaN, as any comparison with NaN is unequal
        postfix = 'must be an integer';
    elseif val < min_val
        postfix = sprintf('cannot be less than %d', min_val);
    elseif val > max_val
        postfix = sprintf('cannot be greater than %d', max_val);
    else
        % all checks passed
        return
    end

    msg = [label ' ' postfix];
    error(msg);


function [sample_ids, nrepeat] = get_split_sample_ids(bin_counts, opt)
    % For each bin, select which samples are used for each average
    %
    % Inputs:
    %   bin_counts  number of samples in each bin
    %   opt         struct with options; it is used to set the selection
    %               parameters and is also passed to cosmo_sample_unique
    %
    % Outputs:
    %   sample_ids  cell with one element per bin, each the output of
    %               cosmo_sample_unique for that bin (presumably an
    %               nselect x nrepeat matrix with positions in the range
    %               1..bin_counts(k); the caller indexes its columns)
    %   nrepeat     number of averages computed for each bin
    [nselect, nrepeat] = get_selection_params(bin_counts, opt);

    % one set of selections for every bin
    nbins = numel(bin_counts);
    sample_ids = cell(nbins, 1);

    % the same number of samples and repeats is used for every bin; only
    % the number of samples to choose from differs
    for bin = 1:nbins
        sample_ids{bin} = cosmo_sample_unique(nselect, bin_counts(bin), ...
                                              nrepeat, opt);
    end