function bal_partitions = cosmo_balance_partitions(partitions, ds, varargin)
% balances partitions so that each target occurs equally often in each
% training and test chunk
%
% bal_partitions=cosmo_balance_partitions(partitions, ds, ...)
%
% Inputs:
%   partitions        struct with fields:
%     .train_indices  } Each is a 1xN cell (for N chunks) containing the
%     .test_indices   } sample indices for each partition
%   ds                dataset struct with field .sa.targets.
%   'nrepeats',nr     Number of repeats (default: 1). The output will
%                     have nr times as many partitions as the input set.
%                     This option is not compatible with 'nmin'.
%   'nmin',nm         Ensure that each sample occurs at least
%                     nm times in each training set (some samples may
%                     be repeated more often than that). This option is
%                     not compatible with 'nrepeats'.
%   'balance_test'    If set to false, indices in the test set are not
%                     necessarily balanced. The default is true.
%   'seed',sd         Use seed sd for pseudo-random number generation
%                     (default: 1). Different values lead almost always
%                     to different pseudo-random orders. To disable using
%                     a seed - which causes this function to give
%                     different results upon subsequent calls with
%                     identical inputs - use sd=0.
%
% Output:
%   bal_partitions    similar struct as input partitions, except that
%                     - each field is a 1xM cell, with M the total number
%                       of output folds (M=N*nr when using 'nrepeats',nr)
%                     - each unique target is represented about equally
%                       often
%                     - each target in each training chunk occurs equally
%                       often
%
% Examples:
% % generate a simple dataset with unbalanced partitions
% ds=struct();
% ds.samples=zeros(9,2);
% ds.sa.targets=[1 1 2 2 2 3 3 3 3]';
% ds.sa.chunks=[1 2 2 1 1 1 2 2 2]';
% p=cosmo_nfold_partitioner(ds);
% %
% % show original (unbalanced) partitioning
% cosmo_disp(p);
% %|| .train_indices
% %|| { [ 2 [ 1
% %|| 3 4
% %|| 7 5
% %|| 8 6 ]
% %|| 9 ] }
% %|| .test_indices
% %|| { [ 1 [ 2
% %|| 4 3
% %|| 5 7
% %|| 6 ] 8
% %|| 9 ] }
% %
% % make standard balancing (nsets=1); some targets are not used
% q=cosmo_balance_partitions(p,ds);
% cosmo_disp(q);
% %|| .train_indices
% %|| { [ 2 [ 1
% %|| 3 5
% %|| 7 ] 6 ] }
% %|| .test_indices
% %|| { [ 1 [ 2
% %|| 5 3
% %|| 6 ] 7 ] }
% %
% % make balancing where each sample in each training fold is used at
% % least once
% q=cosmo_balance_partitions(p,ds,'nmin',1);
% cosmo_disp(q);
% %|| .train_indices
% %|| { [ 2 [ 2 [ 2 [ 1 [ 1
% %|| 3 3 3 5 4
% %|| 7 ] 9 ] 8 ] 6 ] 6 ] }
% %|| .test_indices
% %|| { [ 1 [ 1 [ 1 [ 2 [ 2
% %|| 5 4 5 3 3
% %|| 6 ] 6 ] 6 ] 7 ] 9 ] }
% %
% % triple the number of partitions and sample from training indices
% q=cosmo_balance_partitions(p,ds,'nrepeats',3);
% cosmo_disp(q);
% %|| .train_indices
% %|| { [ 2 [ 2 [ 2 [ 1 [ 1 [ 1
% %|| 3 3 3 5 4 5
% %|| 7 ] 9 ] 8 ] 6 ] 6 ] 6 ] }
% %|| .test_indices
% %|| { [ 1 [ 1 [ 1 [ 2 [ 2 [ 2
% %|| 5 4 5 3 3 3
% %|| 6 ] 6 ] 6 ] 7 ] 9 ] 8 ] }
%
% Notes:
%   - this function is intended for datasets where the number of
%     samples across targets is not equally distributed. A typical
%     application is MEEG datasets.
%   - an error is raised if a class is absent from the training indices
%     of a fold, or (when 'balance_test' is true) from the test indices
%     of a fold.
%   - By default both the train and test indices are balanced, so that
%     chance accuracy is equal to the inverse of the number of unique
%     targets (1/C with C the number of classes).
%     Balancing is considered a *Good Thing*:
%     * Suppose the entire dataset has 75% samples of
%       class A and 25% samples of class B, but the data does not contain
%       any information that allows for discrimination between the
%       classes. A classifier trained on a subset may always predict the
%       class that occurred most often in the training set, which is
%       class A. If the test set also contains 75% of class A, then
%       classification accuracy would be 75%, which is higher than 1/2
%       (with 2 the number of classes).
%     * Balancing the training set only would address this issue, but it
%       may still be the case that a classifier consistently predicts one
%       class more often than other classes. While this may be unbiased
%       with respect to predictions of one particular class over many
%       dataset instances, it could lead to biases (either above or below
%       chance) in particular instances.
%
% See also: cosmo_nchoosek_partitioner, cosmo_nfold_partitioner
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #

    % user-supplied options override these defaults
    defaults = struct('seed', 1, 'balance_test', true);
    params = cosmo_structjoin(defaults, varargin);

    % the input is explicitly allowed to be unbalanced
    cosmo_check_partitions(partitions, ds, 'unbalanced_partitions_ok', true);

    targets = ds.sa.targets;
    all_classes = unique(targets);

    nfolds_in = numel(partitions.train_indices);

    % element k holds a cell with all output folds derived from input fold k
    train_per_fold = cell(1, nfolds_in);
    test_per_fold = cell(1, nfolds_in);

    for fold = 1:nfolds_in
        train_idx = partitions.train_indices{fold};
        test_idx = partitions.test_indices{fold};

        [train_classes, train_class_pos] = get_classes(targets(train_idx));
        if ~isequal(train_classes, all_classes)
            absent = setdiff(all_classes, train_classes);
            error('missing training class %d in fold %d', absent(1), fold);
        end

        % the training set decides how many output folds this input fold
        % yields; the test set follows that number
        nfolds_out = get_nfolds_out(train_class_pos, params);
        train_per_fold{fold} = sample_indices(train_idx, train_class_pos, ...
                                              nfolds_out, params);

        if params.balance_test
            [test_classes, test_class_pos] = get_classes(targets(test_idx));
            if ~isequal(test_classes, all_classes)
                absent = setdiff(all_classes, test_classes);
                error('missing test class %d in fold %d', absent(1), fold);
            end
            test_per_fold{fold} = sample_indices(test_idx, test_class_pos, ...
                                                 nfolds_out, params);
        else
            % unbalanced test set: reuse the same indices for every
            % output fold
            test_per_fold{fold} = repmat({test_idx}, 1, nfolds_out);
        end
    end

    % flatten the per-input-fold cells into a single row of output folds
    bal_partitions = struct('train_indices', {cat(2, train_per_fold{:})}, ...
                            'test_indices', {cat(2, test_per_fold{:})});

    bal_partitions = ensure_sorted_indices(bal_partitions);

    % sanity check; this time without allowing unbalanced partitions
    cosmo_check_partitions(bal_partitions, ds);

function partitions = ensure_sorted_indices(partitions)
% Sort the indices of every fold in .train_indices and .test_indices
%
% Input/Output:
%   partitions   struct with cell fields .train_indices and
%                .test_indices; on return each cell element that was not
%                already sorted has been replaced by its sorted version

    for name_cell = {'train_indices', 'test_indices'}
        name = name_cell{1};
        folds = partitions.(name);

        for fold = 1:numel(folds)
            fold_idx = folds{fold};

            % leave already-sorted folds untouched
            if issorted(fold_idx)
                continue
            end

            partitions.(name){fold} = sort(fold_idx);
        end
    end

function tr_folds_out = sample_indices(target_idx, fold_class_pos, ...
                                       nfolds_out, params)
% Draw balanced samples from the indices of a single input fold
%
% Inputs:
%   target_idx       sample indices of the input fold
%   fold_class_pos   cell with, for each class, positions into target_idx
%   nfolds_out       number of output folds to generate
%   params           options struct, passed on to sample_class_pos
%
% Output:
%   tr_folds_out     1 x nfolds_out cell; each element contains the
%                    elements of target_idx selected for one output fold

    % positions (relative to target_idx) selected for each output fold
    sampled_pos = sample_class_pos(fold_class_pos, nfolds_out, params);

    % map positions back to the sample indices they refer to
    tr_folds_out = cellfun(@(pos) target_idx(pos), sampled_pos, ...
                           'UniformOutput', false);

function [classes, class_pos] = get_classes(targets)
% Find the unique classes in targets and where each one occurs
%
% Thin wrapper around cosmo_index_unique for a single vector of targets.
%
% Outputs:
%   classes      first element of the second output of cosmo_index_unique
%                (presumably the unique values in targets - see that
%                function's documentation)
%   class_pos    first output of cosmo_index_unique (presumably a cell
%                with, for each class, the positions in targets where it
%                occurs)

    [class_pos, unique_values] = cosmo_index_unique({targets});
    classes = unique_values{1};

function nfolds = get_nfolds_out(class_pos, params)
% Determine how many output folds are needed for one input fold
%
% Inputs:
%   class_pos   cell; each element contains the sample positions of one
%               class
%   params      struct that may contain either of the fields:
%     .nmin       multiplier applied to the rounded-up ratio between
%                 the largest and smallest class size
%     .nrepeats   number of output folds to generate
%
% Output:
%   nfolds      - ceil(max_class_size / min_class_size) * params.nmin
%                 if .nmin is set
%               - params.nrepeats if .nrepeats is set
%               - 1 if neither is set
%
% An error is raised if both .nmin and .nrepeats are set.

    has_nmin = isfield(params, 'nmin');
    has_nrepeats = isfield(params, 'nrepeats');

    if has_nmin && has_nrepeats
        error(['options ''nmin'' and ''nrepeats'' are ' ...
               'mutually exclusive']);
    end

    if has_nmin
        % the largest class must be covered using samples of the size of
        % the smallest class, hence the rounded-up ratio
        targets_hist = cellfun(@numel, class_pos);
        nsamples_ratio = max(targets_hist) / min(targets_hist);
        nfolds = ceil(nsamples_ratio) * params.nmin;
    elseif has_nrepeats
        nfolds = params.nrepeats;
    else
        nfolds = 1;
    end

function folds = sample_class_pos(class_pos, nfolds, params)
% Draw nfolds class-balanced samples of positions from class_pos
%
% Inputs:
%   class_pos   cell with one element per class, each containing the
%               positions of the samples in that class.
%               NOTE(review): the vertical concatenations below assume
%               an Nx1 cell with column vectors - confirm against
%               cosmo_index_unique.
%   nfolds      number of output folds
%   params      struct with field .seed, passed to cosmo_rand
%
% Output:
%   folds       1 x nfolds cell; folds{j} is the concatenation (along
%               the first dimension) of the positions selected for each
%               class in fold j. Every class contributes as many
%               positions as the smallest class has samples.

    nclasses = numel(class_pos);
    class_count = cellfun(@numel, class_pos);
    nsamples_per_class = min(class_count);

    % class k uses random values boundaries(k)+1 : boundaries(k+1)
    boundaries = [0; cumsum(class_count)];
    nsamples = boundaries(end);

    % single call to generate pseudo-random uniform data
    uniform_random_all = cosmo_rand(nsamples, 1, 'seed', params.seed);

    idxs = cell(nfolds, nclasses);

    % process each class separately
    for k = 1:nclasses
        uniform_random_pos = (boundaries(k) + 1):boundaries(k + 1);

        % sorting random values gives a random order of this class'
        % samples; the sorted values themselves are not needed
        [unused, order] = sort(uniform_random_all(uniform_random_pos));

        nrepeats = ceil(nsamples_per_class * nfolds / numel(order));

        % build sequence by repeating the random order as many times as
        % necessary; it is read below through linear indexing
        seq = repmat(order, 1, nrepeats);

        for j = 1:nfolds
            % fold j takes the j-th consecutive run of
            % nsamples_per_class elements from the sequence
            seq_idx = nsamples_per_class * (j - 1) + (1:nsamples_per_class);
            idxs{j, k} = class_pos{k}(seq(seq_idx));
        end
    end

    % join the positions of all classes for each fold
    folds = cell(1, nfolds);
    for j = 1:nfolds
        folds{j} = cat(1, idxs{j, :});
    end