function test_suite = test_stat
% Entry point of the unit tests for cosmo_stat
%
% The test cases themselves are the local functions in this file whose
% names start with 'test_'. Presumably the initTestSuite script (part of
% the test framework, not shown here) collects them and assigns the
% output test_suite - confirm against the test framework in use.
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
% test_functions is not referenced again in this function, and no
% statement visible here assigns test_suite; both are presumably handled
% by initTestSuite, which runs as a script in this workspace
initTestSuite;
function r = randint(x)
    % Draw a random integer using a single call to rand()
    %
    % Input:
    %   x    scale factor applied to the uniform random draw
    %
    % Output:
    %   r    the draw multiplied by x, increased by 10 and rounded up
    offset = 10;
    scaled_draw = rand() * x;
    r = ceil(scaled_draw + offset);
function test_stat_correspondence
% Compare cosmo_stat output with the ttest, ttest2 and anova1 / anova
% functions of the platform
%
% The test is skipped when no function named 'ttest' is found.
% For F, one-sample t and two-sample t statistics the test compares
% - the raw statistic (.samples),
% - the label in .sa.stats (including degrees of freedom), and
% - the p-values
% and it verifies that an exception is thrown for several datasets with
% modified targets or chunks.
is_matlab = cosmo_wtf('is_matlab');
if isempty(which('ttest'))
cosmo_notify_test_skipped('ttest is not available');
return
end
% test conformity with matlab's stat functions
ntargets = randint(5);
nchunks = randint(5);
ds = cosmo_synthetic_dataset('nchunks', nchunks, 'ntargets', ntargets, 'sigma', 0);
ds.samples = randn(size(ds.samples)); % full random data
% keep three features only
ds = cosmo_slice(ds, [2 5 6], 2);
[ns, nf] = size(ds.samples);
% reference F values and p-values, computed one feature at a time
f = zeros(1, nf);
p = zeros(1, nf);
for k = 1:nf
% anova1 is used when running in Matlab, or when no function named
% 'anova' is found; otherwise anova is called with four outputs
if is_matlab || isempty(which('anova'))
[p(k), tab] = anova1(ds.samples(:, k), ds.sa.targets, 'off');
% entries of the anova1 table that are compared with the cosmo_stat
% output below: the F value and the two degrees of freedom
f(k) = tab{2, 5};
df = [tab{2:3, 3}];
else
[p(k), f(k), df_b, df_w] = anova(ds.samples(:, k), ds.sa.targets);
df = [df_b df_w];
end
end
% f stat
% every sample gets its own chunk value
ds.sa.chunks = (1:ns)';
ff = cosmo_stat(ds, 'F');
assertVectorsAlmostEqual(f, ff.samples);
% df holds the values from the last feature processed in the loop above
assertEqual(ff.sa.stats, {sprintf('Ftest(%d,%d)', df)});
pp = cosmo_stat(ds, 'F', 'p');
assertVectorsAlmostEqual(p, pp.samples);
% with all samples in a single chunk an exception is expected
ds.sa.chunks(:) = 1;
assertExceptionThrown(@()cosmo_stat(ds, 'F'), '');
% t stat
% 'p' means that no 'tail' argument is passed to ttest / ttest2
tails = {'p', 'left', 'right', 'both'};
for k = 1:numel(tails)
% one-sample ttest
% at this point ds has all chunks set to 1 and its original targets
assertExceptionThrown(@()cosmo_stat(ds, 't'), '');
% single target value, every sample in its own chunk
ds1 = ds;
ds1.sa.targets(:) = 10;
ds1.sa.chunks = (1:ns)';
tail = tails{k};
if strcmp(tail, 'p')
ttest_arg = cell(0);
else
ttest_arg = {'tail', tail};
end
% test t-statistic
[h, p, ci, stats] = ttest_wrapper(ds.samples, 0, ttest_arg{:});
tt = cosmo_stat(ds1, 't');
assertVectorsAlmostEqual(stats.tstat, tt.samples);
assertEqual(tt.sa.stats, {sprintf('Ttest(%d)', stats.df(1))});
pp = cosmo_stat(ds1, 't', tail);
assertVectorsAlmostEqual(p, pp.samples);
% with all samples in a single chunk an exception is expected
ds1.sa.chunks(:) = 1;
assertExceptionThrown(@()cosmo_stat(ds1, 't'), '');
% two-sample (unpaired) ttest
% random assignment of samples to targets 1 and 2, with each sample in
% its own chunk
ds2 = ds;
i = randperm(ns)';
ds2.sa.targets = mod(i, 2) + 1;
ds2.sa.chunks = i;
ds_sp = cosmo_split(ds2, 'targets');
x = ds_sp{1}.samples;
y = ds_sp{2}.samples;
[h, p, ci, stats] = ttest2_wrapper(x, y, ttest_arg{:});
tt = cosmo_stat(ds2, 't2');
assertVectorsAlmostEqual(stats.tstat, tt.samples);
assertEqual(tt.sa.stats, {sprintf('Ttest(%d)', stats.df(1))});
pp = cosmo_stat(ds2, 't2', tail);
assertVectorsAlmostEqual(p, pp.samples);
% NOTE(review): the three checks below modify ds2, yet the dataset passed
% to cosmo_stat is ds1 (which at this point has a single target value and
% all chunks set to 1), so the modifications to ds2 are never exercised.
% Possibly ds2 was intended - confirm before changing, because the
% modified ds2 is not guaranteed to be invalid for every random draw.
ds2.sa.chunks(1) = ds2.sa.chunks(1) + 1;
assertExceptionThrown(@()cosmo_stat(ds1, 't2'), '');
ds2.sa.chunks(1) = ds2.sa.chunks(1) - 1;
ds2.sa.targets(1) = ds2.sa.targets(1) + 1;
assertExceptionThrown(@()cosmo_stat(ds1, 't2'), '');
ds2.sa.targets(:) = 1;
assertExceptionThrown(@()cosmo_stat(ds1, 't2'), '');
end
% ds still has all chunks set to 1 and its original targets
assertExceptionThrown(@()cosmo_stat(ds, 't2'), '');
assertExceptionThrown(@()cosmo_stat(ds, 't'), '');
function test_stat_contrast()
    % Test cosmo_stat on a dataset that has a .sa.contrast field
    %
    % First the 'F' statistic with 'z' output is compared with
    % pre-computed values; then several combinations of contrast values,
    % chunks and statistic names are verified to raise an exception.
    ds = cosmo_synthetic_dataset('nchunks', 6, 'ntargets', 4, 'sigma', 0);

    % contrast is +1 for target 2, -1 for target 4, and zero elsewhere
    targets = ds.sa.targets;
    ds.sa.contrast = (targets == 2) - (targets == 4);

    % give every sample its own chunk value, keeping the originals
    original_chunks = ds.sa.chunks;
    ds.sa.chunks = original_chunks * 6 + ds.sa.targets;

    z_ds = cosmo_stat(ds, 'F', 'z');
    expected_z = [0.2695 0.9675 -0.8770 -1.0542 0.9173 1.2814];
    assertElementsAlmostEqual(z_ds.samples, expected_z, 'absolute', 1e-4);

    % test exceptions
    expect_error = @(varargin)assertExceptionThrown(@() ...
                                        cosmo_stat(varargin{:}), '');

    % only the samples with a non-zero contrast
    ds_nonzero = cosmo_slice(ds, ds.sa.contrast ~= 0);
    t_names = {'t', 't2'};
    for k = 1:numel(t_names)
        expect_error(ds_nonzero, t_names{k});
    end

    % modified contrast values
    ds.sa.contrast(1) = 1;
    expect_error(ds, 'F');

    ds.sa.contrast(ds.sa.targets == 1) = 1;
    expect_error(ds, 'F');

    % within subject F
    ds.sa.chunks = original_chunks;
    all_names = {'F', 't', 't2'};
    for k = 1:numel(all_names)
        expect_error(ds, all_names{k});
    end
function [h, p, ci, stats] = ttest_wrapper(varargin)
    % Call ttest through general_ttestX_wrapper
    %
    % All input arguments are forwarded unchanged; the four outputs are
    % those returned by general_ttestX_wrapper.
    ttest_handle = @ttest;
    [h, p, ci, stats] = general_ttestX_wrapper(ttest_handle, varargin{:});
function [h, p, ci, stats] = ttest2_wrapper(varargin)
    % Call ttest2 through general_ttestX_wrapper
    %
    % All input arguments are forwarded unchanged; the four outputs are
    % those returned by general_ttestX_wrapper.
    ttest2_handle = @ttest2;
    [h, p, ci, stats] = general_ttestX_wrapper(ttest2_handle, varargin{:});
function [h, p, ci, stats] = general_ttestX_wrapper(func, varargin)
    % Call a t-test function, adapting the arguments to its signature
    %
    % Inputs:
    %   func        function handle (the callers in this file pass
    %               @ttest or @ttest2)
    %   varargin    arguments for func; two leading arguments that can be
    %               followed by key-value pairs
    %
    % Outputs:
    %   h, p, ci, stats    the four outputs of func
    %
    % The number of declared inputs of func decides how it is called:
    %   5 or 6 (old Matlab):  the key-value pairs for 'alpha', 'tail' and
    %                         'dim' are converted to positional arguments
    %   -3 (GNU Octave and recent Matlab): arguments are passed unchanged
    % Any other value makes an assertion fail.
    n_declared = nargin(func);

    if n_declared == 5 || n_declared == 6
        % old Matlab
        positional_keys = {'alpha', 'tail', 'dim'};
        call_args = remove_keys_from_arguments(2, positional_keys, ...
                                               varargin);
    elseif n_declared == -3
        % GNU Octave and recent Matlab
        call_args = varargin;
    else
        assert(false);
    end

    [h, p, ci, stats] = func(call_args{:});
function short_args = remove_keys_from_arguments(skip_count, keys, args)
    % Convert trailing key-value arguments to positional arguments
    %
    % Inputs:
    %   skip_count    number of leading elements in args that are copied
    %                 unchanged
    %   keys          cell with K key names, in the order in which their
    %                 values must appear in the output
    %   args          cell with skip_count leading elements followed by
    %                 key-value pairs
    %
    % Output:
    %   short_args    1 x (skip_count+K) cell: the leading elements of
    %                 args followed by, for each key in keys, the value
    %                 that follows that key in args, or [] if the key is
    %                 absent
    %
    % Notes:
    %   - a name in args matches a key if it starts with that key
    %   - if more than one name matches a key, the first match is used
    n_keys = numel(keys);

    % cell() fills every element with [], which is also the value for a
    % key that is absent
    short_args = cell(1, skip_count + n_keys);
    short_args(1:skip_count) = args(1:skip_count);

    % names are at every other position after the leading arguments
    name_positions = (skip_count + 1):2:numel(args);
    names = args(name_positions);

    for k = 1:n_keys
        key = keys{k};

        % strncmp gives the prefix matching of the not-recommended
        % strmatch; find(...,1) guarantees a scalar index
        hit = find(strncmp(key, names, numel(key)), 1);

        if ~isempty(hit)
            % the value directly follows its name
            short_args{skip_count + k} = args{name_positions(hit) + 1};
        end
    end
function test_stat_no_division_by_zero_error()
    % Verify that cosmo_stat leaves the last warning empty when all
    % samples are zero
    %
    % The warning state present before the test is put back when this
    % function exits.
    [previous_msg, previous_id] = lastwarn();
    restore_warning = onCleanup(@()lastwarn(previous_msg, previous_id));

    % start with an empty last warning
    lastwarn('');

    zero_ds = cosmo_synthetic_dataset('ntargets', 1, 'nchunks', 1);
    zero_ds.samples(:) = 0;

    % output is not needed; only the warning state matters
    cosmo_stat(zero_ds, 't', 'z');

    last_message = lastwarn();
    assertEqual(last_message, '');
function test_stat_same_nan_fstat()
    % Verify that the F statistic is NaN for exactly those features that
    % contain NaN values
    ds = cosmo_synthetic_dataset('nchunks', 4, 'ntargets', 3);

    % set three features to NaN for all samples in the third chunk
    nan_rows = ds.sa.chunks == 3;
    nan_cols = [1 3 5];
    ds.samples(nan_rows, nan_cols) = NaN;

    f_ds = cosmo_stat(ds, 'F'); % should be ok

    feature_has_nan = any(isnan(ds.samples), 1);
    assertEqual(isnan(f_ds.samples), feature_has_nan);
function test_stat_exceptions()
    % Verify that cosmo_stat throws an exception for a set of argument
    % combinations
    expect_error = @(varargin)assertExceptionThrown(@() ...
                                        cosmo_stat(varargin{:}), '');

    % dataset with three targets
    ds3 = cosmo_synthetic_dataset('ntargets', 3);
    args_for_ds3 = {{'foo'}, {'F', 'foo'}, {'t'}};
    for k = 1:numel(args_for_ds3)
        expect_error(ds3, args_for_ds3{k}{:});
    end

    % dataset restricted to the samples of the first target
    ds1 = cosmo_slice(ds3, ds3.sa.targets == 1);
    args_for_ds1 = {{'F'}, {'t2'}};
    for k = 1:numel(args_for_ds1)
        expect_error(ds1, args_for_ds1{k}{:});
    end
function test_stat_missing_values()
    % Test cosmo_stat on datasets in which some values are NaN
    %
    % Assuming that cosmo_stat works well with no NaNs, try it with some
    % missing values. Five configurations (statistic name, number of
    % targets, and whether every sample gets its own chunk) are each
    % combined with five output types. For every combination this
    % function verifies:
    %   - per feature, that the result for the full dataset matches the
    %     result for that feature with the NaN samples removed (or is
    %     NaN, for raw statistic output of a feature that has NaNs);
    %   - the contents of the .sa field of the result;
    %   - that the output is NaN for all features if all samples of one
    %     target are NaN;
    %   - whether the output is NaN after setting a single sample of a
    %     single-feature dataset to NaN.
    output_labels = {'', 'p', 'z', 'left', 'right'};
    n_outputs = numel(output_labels);

    for k = 1:5
        switch k
            % sa_label_func (below) takes the number of chunks and
            % returns a statistic sa label
            case 1
                stat_name = 't';
                ntargets = 1;
                is_between = true;
                sa_label_func = @(x)sprintf('Ttest(%d)', x - 1);

            case 2
                stat_name = 't';
                ntargets = 2;
                is_between = false;
                sa_label_func = @(x)sprintf('Ttest(%d)', x - 1);

            case 3
                stat_name = 't2';
                ntargets = 2;
                is_between = true;
                sa_label_func = @(x)sprintf('Ttest(%d)', x - 2);

            case 4
                stat_name = 'F';
                ntargets = randint(5);
                is_between = true;
                sa_label_func = @(x)sprintf('Ftest(%d,%d)', ...
                                            ntargets - 1, x - ntargets);

            case 5
                stat_name = 'F';
                ntargets = randint(5);
                is_between = false;
                sa_label_func = @(x)sprintf('Ftest(%d,%d)', ...
                                            ntargets - 1, ...
                                            x * (ntargets - 1) - ntargets + 1);

            otherwise
                assert(false);
        end

        nchunks = randint(10);

        % make dataset with some elements set to NaN
        ds = cosmo_synthetic_dataset('ntargets', ntargets, ...
                                     'nchunks', nchunks);

        if is_between
            % every sample gets its own chunk; from here on nchunks is
            % the number of samples
            nchunks = numel(ds.sa.chunks);
            ds.sa.chunks(:) = 1:nchunks;
        end

        [nsamples, nfeatures] = size(ds.samples);

        % at least one column has no NaN values
        attempt = 100;
        while true
            attempt = attempt - 1;
            if attempt == 0
                error('unable to generate data');
            end

            % add some NaNs
            nan_ratio = .1 + rand() * .2;

            % NaNs from earlier attempts are kept, as ds is not reset
            non_nan_col = ceil(1 + rand() * (nfeatures - 1));
            for col = 1:nfeatures
                if col == non_nan_col
                    continue
                elseif col == 1
                    % at least one NaN
                    % the column is given explicitly, so that the row
                    % mask does not rely on linear indexing to end up in
                    % the first column
                    ds.samples(ds.sa.chunks == 1, col) = NaN;
                else
                    % subset of chunks set to NaN
                    [unused, idx] = sort(rand(1, nchunks));
                    nan_chunks = idx(1:ceil(nan_ratio * nchunks));
                    msk = any(bsxfun(@eq, nan_chunks, ds.sa.chunks), 2);
                    ds.samples(msk, col) = NaN;
                end
            end

            % verify there are some NaNs
            has_nan = any(isnan(ds.samples(:)));
            not_all_are_nan = any(any(~isnan(ds.samples), 1));

            if has_nan && not_all_are_nan
                break
            end
        end

        for i_output = 1:n_outputs
            output_label = output_labels{i_output};
            result_all = cosmo_stat(ds, stat_name, output_label);

            % check values for each feature
            for col = 1:nfeatures
                ds_full = cosmo_slice(ds, col, 2);
                ds_non_nan = cosmo_slice(ds_full, ...
                                         ~isnan(ds_full.samples));
                result = cosmo_stat(ds_non_nan, stat_name, output_label);
                assertEqual(size(result.samples), [1 1]);

                if isempty(output_label) && any(isnan(ds_full.samples))
                    % raw statistic for a feature with missing values
                    expected_value = NaN;
                else
                    expected_value = result.samples;
                end

                assertElementsAlmostEqual(result_all.samples(:, col), ...
                                          expected_value);
            end

            % check sa
            return_raw_stat = isempty(output_label);
            if return_raw_stat
                % generate label with 'Ttest' or 'Ftest'
                sa_label = sa_label_func(nchunks);
            else
                if strcmp(output_label, 'z')
                    sa_label = 'Zscore()';
                else
                    sa_label = 'Pval()';
                end
            end

            expected_sa = struct();
            expected_sa.stats = {sa_label};
            assertEqual(result_all.sa, expected_sa);

            % ensure that if targets for one condition are all NaN, then
            % output is NaN
            nan_ds = ds;
            msk = nan_ds.sa.targets == max(nan_ds.sa.targets);
            nan_ds.samples(msk, :) = NaN;
            result_all_nan = cosmo_stat(nan_ds, stat_name, output_label);
            assertEqual(result_all_nan.samples, NaN(1, nfeatures));

            % set one target to NaN; if unbalance the resulting
            % sample must be NaN
            tiny_ds = cosmo_slice(ds, 1, 2);
            tiny_ds.samples = randn(size(tiny_ds.samples));
            result = cosmo_stat(tiny_ds, stat_name, output_label);
            assert(~isnan(result.samples));

            % exactly one sample is expected for chunk 1 and target 1
            msk = tiny_ds.sa.chunks == 1 & tiny_ds.sa.targets == 1;
            assert(sum(msk) == 1);
            tiny_ds.samples(msk) = NaN;

            result = cosmo_stat(tiny_ds, stat_name, output_label);

            % NaN is expected for raw statistic output and when chunks
            % are shared across targets; otherwise a non-NaN value
            must_be_nan = return_raw_stat || ~is_between;
            assertEqual(must_be_nan, isnan(result.samples));
        end
    end
function test_stat_extreme_values()
    % Test the t statistic and its z-score for data far away from zero
    %
    % Random data is shifted by 10 in the negative and in the positive
    % direction. The t values must have the sign of the shift; the
    % z-scores must be infinite and also have the sign of the shift.
    template_ds = cosmo_synthetic_dataset('ntargets', 1, ...
                                          'nchunks', 20, ...
                                          'size', 'normal');
    shift_size = 10;
    directions = [-1, 1];

    for k = 1:numel(directions)
        direction = directions(k);

        shifted_ds = template_ds;
        noise = randn(size(shifted_ds.samples));
        shifted_ds.samples = noise + direction * shift_size;

        % raw t statistic
        t_ds = cosmo_stat(shifted_ds, 't');
        t_values = t_ds.samples;
        assertTrue(all(t_values * direction > 0));

        % z-score
        z_ds = cosmo_stat(shifted_ds, 't', 'z');
        z_values = z_ds.samples;
        assertTrue(all(isinf(z_values)));
        assertTrue(all(z_values * direction > 0));
    end
function test_stat_regression()
    % Regression test for cosmo_stat using pre-generated data
    %
    % Every entry returned by get_stat_regression_params describes the
    % arguments for cosmo_stat, the targets to select, whether each
    % sample gets its own chunk value, and either the expected output or
    % the fact that an exception must be thrown.
    full_ds = cosmo_synthetic_dataset('nchunks', 6, 'ntargets', 4, ...
                                      'sigma', 0);
    full_ds = cosmo_slice(full_ds, [2 6], 2);

    all_cases = get_stat_regression_params();
    n_cases = numel(all_cases);

    for i_case = 1:n_cases
        this_case = all_cases{i_case};

        stat_args = this_case{1};
        setup = this_case{2};
        selected_targets = setup{1};
        use_unique_chunks = setup{2};
        expect_error = setup{3};

        keep_msk = cosmo_match(full_ds.sa.targets, selected_targets);
        ds_sel = cosmo_slice(full_ds, keep_msk);

        if use_unique_chunks
            ds_sel.sa.chunks = ds_sel.sa.chunks + 6 * ds_sel.sa.targets;
        end

        run_stat = @()cosmo_stat(ds_sel, stat_args{:});

        if expect_error
            % no expected output is available for these entries
            assertExceptionThrown(run_stat, '');
            continue
        end

        result = run_stat();

        expected = this_case{3};
        assertElementsAlmostEqual(result.samples, expected{1}, ...
                                  'absolute', 1e-4);

        % the label is compared as a 1x1 cell
        expected_sa = struct();
        expected_sa.stats = expected(2);
        assertEqual(result.sa, expected_sa);

        % test errors for wrong assignment of chunks and targets
        if use_unique_chunks
            ds_sel.sa.chunks(1) = ds_sel.sa.chunks(2);
        else
            ds_sel.sa.chunks(ds_sel.sa.chunks == 2) = 1;
        end
        assertExceptionThrown(@()cosmo_stat(ds_sel, stat_args{:}), '');
    end
function params = get_stat_regression_params()
% Return the table of test cases used by test_stat_regression
%
% Output:
%   params    Nx1 cell; each element is a cell with three elements:
%             {{stat_name, output_stat_name},...
%              {targets,is_between_test,should_raise_error},...
%              {samples,.sa.stats{}}}
%             - stat_name and output_stat_name are the second and third
%               argument for cosmo_stat
%             - targets are the target values of the samples to select
%             - is_between_test and should_raise_error are flags
%               (0 or 1)
%             - samples and .sa.stats{} are the expected output values
%               and the expected label in .sa.stats
%             If should_raise_error is 1, the third element is the
%             placeholder string '<should raise error>' instead of a
%             cell.
%
% based on input dataset generated by:
% ds=cosmo_synthetic_dataset('nchunks',6,'ntargets',4,'sigma',0);
% ds=cosmo_slice(ds,[2 6],2);
params = {{{'F', ''}, ...
{[1 2 3 4], 0, 0}...
{[1.63475 2.00315], 'Ftest(3,15)'}}
{{'F', 'z'}, ...
{[1 2 3 4], 0, 0}...
{[0.76051 1.00751], 'Zscore()'}}
{{'F', 'p'}, ...
{[1 2 3 4], 0, 0}...
{[0.22348 0.15684], 'Pval()'}}
{{'F', 'left'}, ...
{[1 2 3 4], 0, 0}...
{[0.77652 0.84316], 'Pval()'}}
{{'F', 'right'}, ...
{[1 2 3 4], 0, 0}...
{[0.22348 0.15684], 'Pval()'}}
{{'F', 'both'}, ...
{[1 2 3 4], 0, 0}...
{[0.44695 0.31369], 'Pval()'}}
{{'t2', ''}, ...
{[3 4], 0, 1}...
'<should raise error>'}
{{'t2', 'z'}, ...
{[3 4], 0, 1}...
'<should raise error>'}
{{'t2', 'p'}, ...
{[3 4], 0, 1}...
'<should raise error>'}
{{'t2', 'left'}, ...
{[3 4], 0, 1}...
'<should raise error>'}
{{'t2', 'right'}, ...
{[3 4], 0, 1}...
'<should raise error>'}
{{'t2', 'both'}, ...
{[3 4], 0, 1}...
'<should raise error>'}
{{'t', ''}, ...
{[3 4], 0, 0}...
{[-0.47494 -0.73292], 'Ttest(5)'}}
{{'t', 'z'}, ...
{[3 4], 0, 0}...
{[-0.44704 -0.68000], 'Zscore()'}}
{{'t', 'p'}, ...
{[3 4], 0, 0}...
{[0.65485 0.49650], 'Pval()'}}
{{'t', 'left'}, ...
{[3 4], 0, 0}...
{[0.32742 0.24825], 'Pval()'}}
{{'t', 'right'}, ...
{[3 4], 0, 0}...
{[0.67258 0.75175], 'Pval()'}}
{{'t', 'both'}, ...
{[3 4], 0, 0}...
{[0.65485 0.49650], 'Pval()'}}
{{'F', ''}, ...
{[1 2 3 4], 1, 0}...
{[1.84801 1.07461], 'Ftest(3,20)'}}
{{'F', 'z'}, ...
{[1 2 3 4], 1, 0}...
{[0.95018 0.29948], 'Zscore()'}}
{{'F', 'p'}, ...
{[1 2 3 4], 1, 0}...
{[0.17101 0.38229], 'Pval()'}}
{{'F', 'left'}, ...
{[1 2 3 4], 1, 0}...
{[0.82899 0.61771], 'Pval()'}}
{{'F', 'right'}, ...
{[1 2 3 4], 1, 0}...
{[0.17101 0.38229], 'Pval()'}}
{{'F', 'both'}, ...
{[1 2 3 4], 1, 0}...
{[0.34202 0.76457], 'Pval()'}}
{{'t2', ''}, ...
{[3 4], 1, 0}...
{[-0.51279 -0.50269], 'Ttest(10)'}}
{{'t2', 'z'}, ...
{[3 4], 1, 0}...
{[-0.49693 -0.48727], 'Zscore()'}}
{{'t2', 'p'}, ...
{[3 4], 1, 0}...
{[0.61924 0.62607], 'Pval()'}}
{{'t2', 'left'}, ...
{[3 4], 1, 0}...
{[0.30962 0.31303], 'Pval()'}}
{{'t2', 'right'}, ...
{[3 4], 1, 0}...
{[0.69038 0.68697], 'Pval()'}}
{{'t2', 'both'}, ...
{[3 4], 1, 0}...
{[0.61924 0.62607], 'Pval()'}}
{{'t', ''}, ...
{[4], 1, 0}...
{[-0.65552 0.32687], 'Ttest(5)'}}
{{'t', 'z'}, ...
{[4], 1, 0}...
{[-0.61117 0.30942], 'Zscore()'}}
{{'t', 'p'}, ...
{[4], 1, 0}...
{[0.54109 0.75700], 'Pval()'}}
{{'t', 'left'}, ...
{[4], 1, 0}...
{[0.27055 0.62150], 'Pval()'}}
{{'t', 'right'}, ...
{[4], 1, 0}...
{[0.72945 0.37850], 'Pval()'}}
{{'t', 'both'}, ...
{[4], 1, 0}...
{[0.54109 0.75700], 'Pval()'}}};