run classify lda skl

%% odd-even classification with LDA classifier
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

%% Define data
% Locate the tutorial data for participant 's01' of the 'ak6' dataset,
% using the tutorial data path stored in the CoSMoMVPA configuration
config = cosmo_config();
data_path = fullfile(config.tutorial_data_path, 'ak6', 's01');

% Load the per-run t-statistics, restricted to the features in the VT mask.
% File names are built with fullfile (as for data_path above) so that the
% path separator is correct on every platform.
ds = cosmo_fmri_dataset(fullfile(data_path, 'glm_T_stats_perrun.nii'), ...
                        'mask', fullfile(data_path, 'vt_mask.nii'));

% remove constant features
ds = cosmo_remove_useless_data(ds);

%% set sample attributes

% Category names, in the order in which they occur within each run
classes = {'monkey', 'lemur', 'mallard', 'warbler', 'ladybug', 'lunamoth'};

% targets: the class indices 1..6, repeated once for each of the 10 runs
% (a 6x10 matrix read out column by column gives a 60x1 vector)
ds.sa.targets = reshape(repmat((1:6)', 1, 10), [], 1);

% chunks: the run index 1..10, with each index repeated for the 6 samples
% that belong to that run (60x1 vector)
ds.sa.chunks = reshape(repmat(1:10, 6, 1), [], 1);

% labels: the class name of every sample, as a 60x1 cell column
ds.sa.labels = repmat(classes(:), 10, 1);

%% Part 1: bird classification; train on even runs, test on odd runs
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Slice the dataset into odd and even runs using the ds.sa.chunks
% attribute, and store the results in new dataset structs called
% 'ds_even' and 'ds_odd'.
% (hint: use the 'mod' function (remainder after division) to see which
% chunks are even or odd)
%%%% >>> Your code here <<< %%%%

%% discriminate between mallards and warblers
% labels of the two bird categories used in this part
categories = {'mallard', 'warbler'};

% Select the samples where .sa.labels matches one of the categories,
% for the even and odd runs separately. Slice the dataset twice and store
% the results in 'ds_even_birds' and 'ds_odd_birds'.
% (use cosmo_match with .sa.labels and categories to define a mask,
% then cosmo_slice to select the data)
%%%% >>> Your code here <<< %%%%

% show the data
% (this requires 'ds_even_birds' and 'ds_odd_birds' to be defined above)
fprintf('Even data:\n');
cosmo_disp(ds_even_birds);

fprintf('Odd data:\n');
cosmo_disp(ds_odd_birds);

% train on even, test on odd
%
% Use cosmo_classify_lda to get predicted targets for the odd runs when
% training on the even runs, and assign these predictions to
% a variable 'test_pred'.
% (hint: use .samples and .sa.targets from ds_even_birds, and
%        use .samples from ds_odd_birds)
%%%% >>> Your code here <<< %%%%

% Assign the real targets of the odd runs to a variable 'test_targets'
%%%% >>> Your code here <<< %%%%

% show real and predicted labels side by side, one row per test sample
% (first column: real target, second column: predicted target)
fprintf('\ntarget predicted\n');
disp([test_targets test_pred]);

% Compare the predicted labels for the odd runs with the actual targets
% to compute the accuracy. Store the accuracy in a variable 'accuracy'.
%%%% >>> Your code here <<< %%%%
fprintf('\nLDA birds even-odd: accuracy %.3f\n', accuracy);

% compare with naive bayes classification
% (hint: do classification as above, but use cosmo_classify_naive_bayes;
%  store the result in 'accuracy' again, as that is the variable
%  printed below)
%%%% >>> Your code here <<< %%%%
fprintf('\nNaive Bayes birds even-odd: accuracy %.3f\n', accuracy);

%% Part 2: all categories; train/test on even/odd runs and vice versa
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% This is as above, but without slicing to get the samples with bird
% species. In other words, just use 'ds_even' and 'ds_odd'.
% In both steps, store the result in 'accuracy', the variable that is
% printed after each step.
%
% First, train on even, test on odd
%%%% >>> Your code here <<< %%%%
fprintf('\nLDA all categories even-odd: accuracy %.3f\n', accuracy);

% Now train on odd, test on even
% (note: Part 3 below reads the variables 'test_targets' and 'test_pred'
%  and expects all six targets to be present, so keep using these
%  variable names for the real and predicted targets here)
%%%% >>> Your code here <<< %%%%
fprintf('\nLDA all categories odd-even: accuracy %.3f\n', accuracy);

%% Part 3: build confusion matrix
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Manually build the confusion matrix for the six categories, based on
% the variables 'test_targets' and 'test_pred' that are in the workspace

% first, allocate space for the confusion matrix
nclasses = numel(classes); % should be 6
confusion_matrix = zeros(nclasses); % nclasses x nclasses (6x6) matrix

% sanity check to ensure that the targets are exactly 1..nclasses;
% the bound is taken from nclasses rather than repeated as a literal so
% that it cannot get out of sync with the size of the matrix above
assert(isequal(unique(test_targets), (1:nclasses)'));

% in confusion matrix, the i-th row and j-th column should contain
% the number of times that a sample with test_targets==i was predicted as
% test_pred==j. Use a nested for-loop (a for-loop in a for-loop) to count
% this for all combinations of i (1 to 6) and j (1 to 6)
%%%% >>> Your code here <<< %%%%

% CoSMoMVPA can generate the confusion matrix using cosmo_confusion_matrix;
% the check below ensures that your solution matches the one produced by
% CoSMoMVPA
confusion_matrix_alt = cosmo_confusion_matrix(test_targets, test_pred);
if ~isequal(confusion_matrix, confusion_matrix_alt)
    error('your confusion matrix does not match the expected output');
end

% Show the confusion matrix as an image in a new figure window:
% rows correspond to the real targets, columns to the predictions
figure;
imagesc(confusion_matrix, [0 5]); % color scale fixed to the range 0..5
title('confusion matrix');

% one tick per category on both axes, labelled with the category names
set(gca, 'XTick', 1:nclasses, 'XTickLabel', classes, ...
         'YTick', 1:nclasses, 'YTickLabel', classes);
xlabel('predicted');
ylabel('target');
colorbar;