odd-even classification with LDA classifier
For CoSMoMVPA's copyright information and license terms,
see the COPYING file distributed with CoSMoMVPA.
Contents
Define data
% Locate the data folder for subject 's01' of the 'ak6' dataset below the
% tutorial data path stored in the CoSMoMVPA configuration.
config = cosmo_config();
data_path = fullfile(config.tutorial_data_path, 'ak6', 's01');

% Build the file names with fullfile, as is done for data_path, instead of
% concatenating with a hard-coded '/' separator.
data_fn = fullfile(data_path, 'glm_T_stats_perrun.nii');
mask_fn = fullfile(data_path, 'vt_mask.nii');

% Load the per-run statistics, passing vt_mask.nii as the 'mask' option.
ds = cosmo_fmri_dataset(data_fn, 'mask', mask_fn);

% NOTE(review): presumably drops features that carry no information
% (e.g. constant or non-finite ones) - confirm in the function's help.
ds = cosmo_remove_useless_data(ds);
set sample attributes
% The 60 samples are labelled as 10 consecutive chunks of 6 samples each;
% within every chunk the targets run 1..6, in the order of the names
% listed in 'classes'.
classes = {'monkey', 'lemur', 'mallard', 'warbler', 'ladybug', 'lunamoth'};

sample_idx = (0:59)';                         % zero-based sample index
ds.sa.targets = mod(sample_idx, 6) + 1;       % 1,2,...,6,1,2,...,6,...
ds.sa.chunks = floor(sample_idx / 6) + 1;     % 1,1,1,1,1,1,2,2,...,10

% one label per sample, looked up by target, as a 60x1 cell column
ds.sa.labels = reshape(classes(ds.sa.targets), [], 1);
Part 1: bird classification; train on even runs, test on odd runs
% Split the samples by the parity of their chunk number. Chunks are
% integers, so every sample is in exactly one of the two masks.
odd_msk = mod(ds.sa.chunks, 2) == 1;
even_msk = ~odd_msk;

ds_odd = cosmo_slice(ds, odd_msk);
ds_even = cosmo_slice(ds, even_msk);
discriminate between mallards and warblers
% Keep, in each half, only the samples labelled with one of the two
% bird categories.
categories = {'mallard', 'warbler'};

msk_even_birds = cosmo_match(ds_even.sa.labels, categories);
msk_odd_birds = cosmo_match(ds_odd.sa.labels, categories);

ds_even_birds = cosmo_slice(ds_even, msk_even_birds);
ds_odd_birds = cosmo_slice(ds_odd, msk_odd_birds);

% Show both reduced datasets, each preceded by its header line.
to_show = {'Even data:\n', ds_even_birds; ...
           'Odd data:\n', ds_odd_birds};
for row = 1:size(to_show, 1)
    fprintf(to_show{row, 1});
    cosmo_disp(to_show{row, 2});
end
% Training set: bird samples from the even runs.
train_samples = ds_even_birds.samples;
train_targets = ds_even_birds.sa.targets;

% Test set: bird samples from the odd runs, with their true targets.
test_samples = ds_odd_birds.samples;
test_targets = ds_odd_birds.sa.targets;

% Predict the test targets with the LDA classifier.
test_pred = cosmo_classify_lda(train_samples, train_targets, ...
                               test_samples);

% Show true and predicted targets side by side.
fprintf('\ntarget predicted\n');
disp(horzcat(test_targets, test_pred));

% Accuracy is the fraction of test samples predicted correctly.
is_correct = test_pred == test_targets;
accuracy = nnz(is_correct) / numel(is_correct);
fprintf('\nLDA birds even-odd: accuracy %.3f\n', accuracy);
% Same train/test split as for LDA, now with the naive Bayes classifier.
test_targets = ds_odd_birds.sa.targets;
test_pred_nb = cosmo_classify_naive_bayes(train_samples, ...
                                          train_targets, ...
                                          test_samples);

% Accuracy is the fraction of test samples predicted correctly.
is_correct_nb = test_pred_nb == test_targets;
accuracy = nnz(is_correct_nb) / numel(is_correct_nb);
fprintf('\nNaive Bayes birds even-odd: accuracy %.3f\n', accuracy);
Even data:
.a
.vol
.mat
[ -3 0 0 121
0 3 0 -114
0 0 3 -11.1
0 0 0 1 ]
.xform
'scanner_anat'
.dim
[ 80 80 43 ]
.fdim
.labels
{ 'i'
'j'
'k' }
.values
{ [ 1 2 3 ... 78 79 80 ]@1x80
[ 1 2 3 ... 78 79 80 ]@1x80
[ 1 2 3 ... 41 42 43 ]@1x43 }
.sa
.targets
[ 3
4
3
:
4
3
4 ]@10x1
.chunks
[ 2
2
4
:
8
10
10 ]@10x1
.labels
{ 'mallard'
'warbler'
'mallard'
:
'warbler'
'mallard'
'warbler' }@10x1
.samples
[ 2.9 1.72 2.76 ... 4.16 2.56 2.95
1.89 1.37 3.05 ... 4.38 4.46 3.86
1.26 2.41 2.87 ... 1.96 3.74 3.34
: : : : : :
2.09 1.53 1.98 ... 3.21 2.94 4.08
2.13 1.22 2.12 ... 2.3 3.14 1.99
1.5 1.6 2.07 ... 0.491 1.1 1.78 ]@10x384
.fa
.i
[ 32 33 34 ... 28 29 29 ]@1x384
.j
[ 24 24 24 ... 25 25 26 ]@1x384
.k
[ 3 3 3 ... 9 9 9 ]@1x384
Odd data:
.a
.vol
.mat
[ -3 0 0 121
0 3 0 -114
0 0 3 -11.1
0 0 0 1 ]
.xform
'scanner_anat'
.dim
[ 80 80 43 ]
.fdim
.labels
{ 'i'
'j'
'k' }
.values
{ [ 1 2 3 ... 78 79 80 ]@1x80
[ 1 2 3 ... 78 79 80 ]@1x80
[ 1 2 3 ... 41 42 43 ]@1x43 }
.sa
.targets
[ 3
4
3
:
4
3
4 ]@10x1
.chunks
[ 1
1
3
:
7
9
9 ]@10x1
.labels
{ 'mallard'
'warbler'
'mallard'
:
'warbler'
'mallard'
'warbler' }@10x1
.samples
[ 1.3 0.646 0.591 ... 1.51 1.75 3.08
3.08 3.05 4.08 ... 1.88 2.26 2.92
2.13 2.69 2.99 ... 2.03 2.87 1.74
: : : : : :
1.47 0.507 1.87 ... 2.59 3.09 3.51
2.24 2.37 3.27 ... 4.61 2.13 4.32
1.36 1.31 1.52 ... 2.95 2.61 1.18 ]@10x384
.fa
.i
[ 32 33 34 ... 28 29 29 ]@1x384
.j
[ 24 24 24 ... 25 25 26 ]@1x384
.k
[ 3 3 3 ... 9 9 9 ]@1x384
target predicted
3 4
4 4
3 3
4 4
3 3
4 4
3 3
4 4
3 3
4 4
LDA birds even-odd: accuracy 0.900
Naive Bayes birds even-odd: accuracy 0.700
Part 2: all categories; train/test on even/odd runs and vice versa
% Each row lists the training set, the test set and the report format for
% one direction. The odd->even direction runs last, so after the loop
% test_pred and test_targets hold the results of training on odd runs and
% testing on even runs.
directions = {ds_even, ds_odd, ...
              '\nLDA all categories even-odd: accuracy %.3f\n'; ...
              ds_odd, ds_even, ...
              '\nLDA all categories odd-even: accuracy %.3f\n'};

for direction = 1:size(directions, 1)
    ds_train = directions{direction, 1};
    ds_test = directions{direction, 2};
    report_format = directions{direction, 3};

    train_samples = ds_train.samples;
    train_targets = ds_train.sa.targets;
    test_samples = ds_test.samples;
    test_targets = ds_test.sa.targets;

    test_pred = cosmo_classify_lda(train_samples, train_targets, ...
                                   test_samples);

    % fraction of test samples predicted correctly
    accuracy = mean(test_pred == test_targets);
    fprintf(report_format, accuracy);
end
LDA all categories even-odd: accuracy 0.767
LDA all categories odd-even: accuracy 0.733
Part 3: build confusion matrix
% Count, for every (target, predicted) pair, how many test samples have
% that combination. Rows index the true target, columns the prediction.
nclasses = numel(classes);
confusion_matrix = zeros(nclasses);

% The loops below use target values directly as row indices, so the test
% targets must be exactly 1..nclasses (tied to the number of classes
% instead of a hard-coded 6).
assert(isequal(unique(test_targets), (1:nclasses)'));

for predicted = 1:nclasses
    for target = 1:nclasses
        % samples of class 'target' that were classified as 'predicted'
        match_mask = test_pred == predicted & test_targets == target;
        match_count = sum(match_mask);
        confusion_matrix(target, predicted) = match_count;
    end
end

% Cross-check the hand-built matrix against the toolbox implementation.
confusion_matrix_alt = cosmo_confusion_matrix(test_targets, test_pred);
if ~isequal(confusion_matrix, confusion_matrix_alt)
    error('your confusion matrix does not match the expected output');
end
% Show the confusion matrix as an image: rows are targets, columns are
% predictions, and both axes are labelled with the class names.
%
% The upper colour limit is the largest number of test samples of any one
% target (the largest row sum), i.e. the highest count a cell can reach.
% With 5 test samples per class this gives the limits [0 5]. The max(1, .)
% guard keeps the limits valid for an all-zero matrix.
max_count = max(1, max(sum(confusion_matrix, 2)));

figure;
imagesc(confusion_matrix, [0 max_count]);
title('confusion matrix');

set(gca, 'XTick', 1:nclasses, 'XTickLabel', classes);
set(gca, 'YTick', 1:nclasses, 'YTickLabel', classes);
ylabel('target');
xlabel('predicted');
colorbar;