% Tutorial to demonstrate shuffling (multivariate)
%%% How do we tell if the effect of a feature is significant?

addpath('../libs/mTRF-Toolbox_v2/mtrf');
addpath(genpath('../libs/cnsp_utils'));
addpath('../libs/NoiseTools');

% Seed the random number generator so that the random shuffles used below
% differ from run to run
rng('shuffle');

%% Parameters
% --- Data selection
sbj = 6;     % subject id for the data to use (selects dataSub<sbj>.mat)
max_tr = 20; % maximum trial to use (later trials are left out)

% --- Preprocessing
downFs = 64;                  % sampling rate of the EEG data and stimulus before modeling
bandpassFilterRange = [1, 8]; % [high-pass, low-pass] settings given to getHPFilt / getLPFilt;
                              % an entry <= 0 skips the corresponding filter
reRefType = 'Mastoids';       % for referencing to mastoids

% --- TRF model
tmin = 0;                   % first time lag passed to the mTRF functions
tmax = 350;                 % last time lag passed to the mTRF functions
lambdas = [0, 10.^(-2:6)];  % candidate regularization values: 0 and 10^-2 ... 10^6

%% Load stimulus and EEG data
% Get multivariate features (music amplitude and expectation)
% The stimulus features are expected to include the IDyOM expectation and
% surprisal vectors (see addIDyOMvects.m).
% NOTE(review): an earlier comment referred to 'dataStimNew', but the file
% loaded here is dataStim.mat -- confirm it contains the IDyOM features.
load('../datasets/diliBach/dataCND/dataStim.mat'); % loads the 'stim' structure
% Concatenate the features for each trial into one matrix, so that features
% are along columns (stim.data holds one cell per feature (rows) and trial
% (columns))
ntr = size(stim.data,2); % number of trials
catStim = cell(1,ntr);
for n = 1:ntr
    catStim{n} = cell2mat(stim.data(:,n)');
    % Normalize each feature (column) of this trial so that its
    % root-mean-square is equal to 1
    for f = 1:size(catStim{n},2)
        feat_rms = rms(catStim{n}(:,f));
        % A feature that is all zeros in a trial has an RMS of 0; dividing
        % by it would fill the column with NaN and corrupt the model fit,
        % so such a column is left unchanged
        if feat_rms > 0
            catStim{n}(:,f) = catStim{n}(:,f)/feat_rms;
        end
    end
end
stim.data = catStim;
% Keep track of the acoustic vectors (the A model) and the melodic
% expectation vectors (the M model)
% (see stim.names for the parameter names)
acoust_idx = 1:3;
melody_idx = 4:7;

% Get EEG data for the selected subject
load(sprintf('../datasets/diliBach/dataCND/dataSub%d.mat',sbj)); % loads the 'eeg' structure

%% Preprocessing

% Low-pass filter (LPF) the EEG and any external channels.
% Skipped when the upper edge of bandpassFilterRange is not positive.
if bandpassFilterRange(2) > 0
    hd = getLPFilt(eeg.fs,bandpassFilterRange(2));

    % Filter the EEG one trial at a time (an explicit loop here; the
    % external channels below use the equivalent cellfun form)
    for trIdx = 1:numel(eeg.data)
        eeg.data{trIdx} = filtfilthd(hd,eeg.data{trIdx});
    end

    % Apply the same filter to each set of external channels, if present
    if isfield(eeg,'extChan')
        applyLPF = @(x) filtfilthd(hd,x);
        for extIdx = 1:length(eeg.extChan)
            eeg.extChan{extIdx}.data = cellfun(applyLPF,eeg.extChan{extIdx}.data,'UniformOutput',false);
        end
    end

    % Register the operation on the eeg structure
    eeg = cndNewOp(eeg,'LPF');
end

% Downsample the EEG and external channels, but only if the target rate is
% lower than the current one
if downFs < eeg.fs
    eeg = cndDownsample(eeg,downFs);
end

% High-pass filter (HPF), designed at the current (possibly downsampled)
% sampling rate. Skipped when the lower edge of bandpassFilterRange is not
% positive.
if bandpassFilterRange(1) > 0
    hd = getHPFilt(eeg.fs,bandpassFilterRange(1));
    applyHPF = @(x) filtfilthd(hd,x);

    % EEG data, trial by trial
    eeg.data = cellfun(applyHPF,eeg.data,'UniformOutput',false);

    % External channels, if present
    if isfield(eeg,'extChan')
        for extIdx = 1:length(eeg.extChan)
            extData = eeg.extChan{extIdx}.data;
            eeg.extChan{extIdx}.data = cellfun(applyHPF,extData,'UniformOutput',false);
        end
    end

    % Register the operation on the eeg structure
    eeg = cndNewOp(eeg,'HPF');
end

% Re-reference the EEG data (reRefType selects the reference)
eeg = cndReref(eeg,reRefType);

% Drop the zero-padding at the start of each EEG trial and cut the trial to
% the length of the matching stimulus, then zero-center every channel
stim_start = eeg.paddingStartSample; % same offset for every trial
for tr = 1:ntr
    stim_len = size(stim.data{tr},1);
    keep_smp = stim_start+(1:stim_len); % samples aligned with the stimulus
    % detrend(...,0) subtracts the mean of each column (channel)
    eeg.data{tr} = detrend(eeg.data{tr}(keep_smp,:),0);
end

%%% Channel Cz is 48, Fz is 38
nchan = size(eeg.data{1},2); % number of channels

%% Part 1: Nested cross-validation and testing of AM
% Leave out trials after max_tr. If fewer than max_tr trials are available,
% use all of them instead of failing with an indexing error.
ntr_avail = min(numel(stim.data),numel(eeg.data));
if max_tr > ntr_avail
    warning('max_tr = %d exceeds the %d available trials; using %d trials.',max_tr,ntr_avail,ntr_avail);
end
ntr = min(max_tr,ntr_avail);
stim.data = stim.data(1:ntr);
eeg.data = eeg.data(1:ntr);

% Fit an AM model (all stimulus features).
% Outer loop: each trial is left out once for testing.
% Inner loop (mTRFcrossval): cross-validation over the remaining trials to
% choose the regularization parameter lambda.
mdl_tm = tic;
r_AM = NaN(ntr,1);       % test prediction accuracy per left-out trial
opt_lmb_AM = NaN(ntr,1); % lambda selected for each left-out trial
mdl_AM = cell(ntr,1);    % model fit for each left-out trial
for test_tr = 1:ntr
    % test_tr is the testing trial
    % train on all others
    train_trs = setdiff(1:ntr,test_tr);
    % Use cross-validation for the AM model
    stats = mTRFcrossval(stim.data(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,lambdas);
    % identify the optimal lambda parameter
    mn_r = squeeze(mean(mean(stats.r,3),1)); % average cross-validation r values over channels, then folds
    % Take the index of the maximum rather than comparing with ==, so that
    % a tie between lambdas (or all-NaN r values) still yields exactly one
    % lambda; on a tie the first (smallest) lambda is used
    [~,best_idx] = max(mn_r);
    opt_lmb_AM(test_tr) = lambdas(best_idx);
    % fit the AM model with the optimal lambda
    mdl_AM{test_tr} = mTRFtrain(stim.data(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,opt_lmb_AM(test_tr));
    % test on the left out trial
    [~,stats_test] = mTRFpredict(stim.data{test_tr},eeg.data{test_tr},mdl_AM{test_tr});
    r_AM(test_tr) = mean(stats_test.r,2); % average correlation values across channels
end
mdl_elapsed_time = toc(mdl_tm);
fprintf('Completed nested cross-validation and testing of AM @ %.3f s\n',mdl_elapsed_time);

% Average the AM model weights over the models fit for each left-out trial
ndly = length(mdl_AM{1}.t);   % number of time lags in the model
nfeat = size(stim.data{1},2); % number of stimulus features in the AM model
% Stack the weight arrays (feature x lag x channel) along a 4th dimension
% and average over it
all_w = cellfun(@(m) m.w,mdl_AM,'UniformOutput',false);
avgmdl = mean(cat(4,all_w{:}),4);

% Plot the average weights of every feature at one channel, as a function
% of the time lag
cz_chan = 48; % use channel Cz = 48
figure
plot(mdl_AM{1}.t,avgmdl(:,:,cz_chan));
xlabel('Delay (ms)');
ylabel('Average model weight at Cz');
legend(stim.names);

%% Part 1 continued: How well does the AM model perform? Calculate the model's chance performance
% Is the performance of the model significantly above chance? Build a null
% distribution from random circular shifting ('circshift', 100
% permutations), using the lambda that was selected most often
null = mTRFpermute(stim.data,eeg.data,downFs,1,'circshift',tmin,tmax,mode(opt_lmb_AM),'nperm',100);

% Summarize the null distribution
mnullr = mean(null.nullr,2); % average along dimension 2; length(mnullr) = nperm
mdnull = median(mnullr);
lqnull = quantile(mnullr,0.05);
uqnull = quantile(mnullr,0.95);
% p-value for the average across trials: the fraction of null values that
% exceed the true average prediction accuracy
pval_avg = sum(mnullr>mean(r_AM))/length(mnullr);
% d-prime of the true prediction accuracies relative to the null
dpr = (mean(r_AM)-mean(mnullr))/sqrt(0.5*var(r_AM) + 0.5*var(mnullr));

% Left panel: true prediction accuracies (x = 0) next to the null
% distribution (x = 1)
figure
subplot(1,2,1)
hold on
plot(0,r_AM,'b.','MarkerSize',16);
errorbar(1,mdnull,mdnull-lqnull,uqnull-mdnull,'ko','LineWidth',1.5,'MarkerSize',12);
xlim([-1 2]);
set(gca,'XTick',[0 1],'XTickLabel',{'True AM','Null AM (Md +/- 5-95 quantiles)'});
ylabel('Prediction accuracy, averaged across channels (r)');
title(sprintf('P-value of average (one-tailed) = %.3g',pval_avg))

% Right panel: each trial's accuracy z-scored by the null mean and standard
% deviation
z_AM = (r_AM-mean(mnullr))/std(mnullr);
subplot(1,2,2)
plot(z_AM,'b.','MarkerSize',16);
xlabel('Trial number');
ylabel('Z-scored prediction accuracy relative to null distr');
title(sprintf('d-prime = %.3f',dpr));

%% Exercise 1: How well does the just-onset model perform?
% Hint: You can use the following line of code to get specific features, but you
% need to specify which indexes to use with 'onset_idx':
% ** stim_Ao = cellfun(@(x) x(:,onset_idx),stim.data,'UniformOutput',false);
% then use stim_Ao in place of stim.data in the code you write

% Get just the onset features (column 2 of the stimulus matrices is used
% here; see stim.names for the feature names)
stim_Ao = cellfun(@(x) x(:,2),stim.data,'UniformOutput',false);

% Fit an Ao model.
% Test on each left-out trial
mdl_tm = tic;
r_Ao = NaN(ntr,1);       % test prediction accuracy per left-out trial
opt_lmb_Ao = NaN(ntr,1); % lambda selected for each left-out trial
for test_tr = 1:ntr
    % test_tr is the testing trial
    % train on all others
    train_trs = setdiff(1:ntr,test_tr);
    % Use crossvalidation on the onset-only model
    stats = mTRFcrossval(stim_Ao(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,lambdas);
    % identify the optimal lambda parameter
    mn_r = squeeze(mean(mean(stats.r,3),1)); % average cross-validation r values over channels, then folds
    % Take the index of the maximum rather than comparing with ==, so that
    % a tie between lambdas (or all-NaN r values) still yields exactly one
    % lambda; on a tie the first (smallest) lambda is used
    [~,best_idx] = max(mn_r);
    opt_lmb_Ao(test_tr) = lambdas(best_idx);
    % fit the model with the optimal lambda
    mdl_Ao = mTRFtrain(stim_Ao(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,opt_lmb_Ao(test_tr));
    % test on the left out trial
    [~,stats_test] = mTRFpredict(stim_Ao{test_tr},eeg.data{test_tr},mdl_Ao);
    r_Ao(test_tr) = mean(stats_test.r,2); % average correlation values across channels
end
fprintf('Completed nested cross-validation and testing of Ao @ %.3f s\n',toc(mdl_tm));

% Is the performance of the onset-only model significantly above chance?
% Random circular shifting and refitting to get the null distribution
% (100 permutations, using the lambda that was selected most often)
null = mTRFpermute(stim_Ao,eeg.data,downFs,1,'circshift',tmin,tmax,mode(opt_lmb_Ao),'nperm',100);

% Summarize the null distribution
mnullr_Ao = mean(null.nullr,2); % average along dimension 2; length(mnullr_Ao) = nperm
mdnull = median(mnullr_Ao);
lqnull = quantile(mnullr_Ao,0.05);
uqnull = quantile(mnullr_Ao,0.95);
% p-value for the average across trials: the fraction of null values that
% exceed the true average prediction accuracy
pval_avg = sum(mnullr_Ao>mean(r_Ao))/length(mnullr_Ao);
% d-prime of the true prediction accuracies relative to the null
dpr_Ao = (mean(r_Ao)-mean(mnullr_Ao))/sqrt(0.5*var(r_Ao) + 0.5*var(mnullr_Ao));

% Left panel: true prediction accuracies for the onset-only model (x = 0)
% next to the null distribution (x = 1)
figure
subplot(1,2,1)
hold on
plot(0,r_Ao,'b.-','MarkerSize',16);
errorbar(1,mdnull,mdnull-lqnull,uqnull-mdnull,'ko','LineWidth',1.5,'MarkerSize',12);
xlim([-1 2]);
set(gca,'XTick',[0 1],'XTickLabel',{'True Ao','Null Ao (Md +/- 5-95 quantiles)'});
ylabel('Prediction accuracy, averaged across channels (r)');
title(sprintf('P-value of average (one-tailed) = %.3g',pval_avg))

% Right panel: each trial's accuracy z-scored by the null mean and standard
% deviation
z_Ao = (r_Ao-mean(mnullr_Ao))/std(mnullr_Ao);
subplot(1,2,2)
plot(z_Ao,'b.','MarkerSize',16);
xlabel('Trial number');
ylabel('Z-scored prediction accuracy relative to null distr');
title(sprintf('d-prime = %.3f',dpr_Ao));

%% Part 2: Examine the significance of specific features
% We will examine how much the melodic expectation features contribute to
% the model, by comparing the AM model with two reduced models:
%   A      - the acoustic features only
%   AMshuf - all features, but with the M features randomly shifted

% In order to compute an A model, get a cell array of just the acoustic
% features
A = cellfun(@(x) x(:,acoust_idx),stim.data,'UniformOutput',false);

% Create a new set of AM stimuli, where the non-zero values of M are
% randomly shifted on each trial
[AMshuf,shft] = randshift_discrete(stim.data,melody_idx);

redmdltm = tic;
r_A = NaN(ntr,1);            % test prediction accuracy of the A model
r_AMshuf = NaN(ntr,1);       % test prediction accuracy of the AMshuf model
opt_lmb_A = NaN(ntr,1);      % lambda selected for the A model
opt_lmb_AMshuf = NaN(ntr,1); % lambda selected for the AMshuf model
for test_tr = 1:ntr
    % train on all others
    train_trs = setdiff(1:ntr,test_tr);

    % A model
    stats = mTRFcrossval(A(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,lambdas);
    % identify the optimal lambda parameter
    mn_r = squeeze(mean(mean(stats.r,3),1)); % average cross-validation r values over channels, then folds
    % Take the index of the maximum rather than comparing with ==, so that
    % a tie between lambdas (or all-NaN r values) still yields exactly one
    % lambda; on a tie the first (smallest) lambda is used
    [~,best_idx] = max(mn_r);
    opt_lmb_A(test_tr) = lambdas(best_idx);
    % fit the model with the optimal lambda
    mdl_A = mTRFtrain(A(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,opt_lmb_A(test_tr));
    % test on the left out trial
    [~,stats_test] = mTRFpredict(A{test_tr},eeg.data{test_tr},mdl_A);
    r_A(test_tr) = mean(stats_test.r,2); % average correlation values across channels

    % AMshuf model
    stats = mTRFcrossval(AMshuf(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,lambdas);
    % identify the optimal lambda parameter (same tie-safe selection)
    mn_r = squeeze(mean(mean(stats.r,3),1)); % average cross-validation r values over channels, then folds
    [~,best_idx] = max(mn_r);
    opt_lmb_AMshuf(test_tr) = lambdas(best_idx);
    % fit the model with the optimal lambda
    mdl_AMshuf = mTRFtrain(AMshuf(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,opt_lmb_AMshuf(test_tr));
    % test on the left out trial
    [~,stats_test] = mTRFpredict(AMshuf{test_tr},eeg.data{test_tr},mdl_AMshuf);
    r_AMshuf(test_tr) = mean(stats_test.r,2); % average correlation values across channels
end
red_elapsed_time = toc(redmdltm);
fprintf('Reduced models calculated @ %.3f s\n',red_elapsed_time);

% Compare the full AM model with the reduced / shuffled models.
% Per-trial differences in prediction accuracy (positive = AM is better)
dr_A = r_AM-r_A;
dr_AMshuf = r_AM-r_AMshuf;
% Significance of the pairwise differences (signed rank test)
p_adiff = signrank(r_AM,r_A);
p_ashuff = signrank(r_AM,r_AMshuf);

figure
hold on
plot([-1 2],[0 0],'k--') % reference line at zero difference
% One dot per trial, one x position per comparison
plot(0:1,[dr_A dr_AMshuf],'b.','MarkerSize',16);
xlim([-1 2]);
set(gca,'XTick',[0 1],'XTickLabel',{'AM-A','AM-AMshuf'});
ylabel('\Deltar');
title(sprintf('Signed-rank test: p_{AM-A} = %.3f, p_{AM-AMshuf} = %.3f',p_adiff,p_ashuff));

%% Exercise 2: Are the absolute pitch values important for the acoustic model?
% Hint: Take a look at stim.names to identify which of the features of the
% acoustic model is absolute pitch

% AnoP = the acoustic model without pitch features
AnoP = cellfun(@(x) x(:,[1 2]),stim.data,'UniformOutput',false); % feature 3 is pitch
% AshufP = the acoustic model but we have randomly shuffled the pitch
% values
AshufP = randshift_discrete(A,3); % randomly shift the pitch features

r_AnoP = NaN(ntr,1);         % test prediction accuracy of the AnoP model
r_AshufP = NaN(ntr,1);       % test prediction accuracy of the AshufP model
opt_lmb_AnoP = NaN(ntr,1);   % lambda selected for the AnoP model
opt_lmb_AshufP = NaN(ntr,1); % lambda selected for the AshufP model
for test_tr = 1:ntr
    % train on all others
    train_trs = setdiff(1:ntr,test_tr);

    % AnoP model
    stats = mTRFcrossval(AnoP(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,lambdas);
    % identify the optimal lambda parameter
    mn_r = squeeze(mean(mean(stats.r,3),1)); % average cross-validation r values over channels, then folds
    % Take the index of the maximum rather than comparing with ==, so that
    % a tie between lambdas (or all-NaN r values) still yields exactly one
    % lambda; on a tie the first (smallest) lambda is used
    [~,best_idx] = max(mn_r);
    opt_lmb_AnoP(test_tr) = lambdas(best_idx);
    % fit the model with the optimal lambda
    mdl_AnoP = mTRFtrain(AnoP(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,opt_lmb_AnoP(test_tr));
    % test on the left out trial
    [~,stats_test] = mTRFpredict(AnoP{test_tr},eeg.data{test_tr},mdl_AnoP);
    r_AnoP(test_tr) = mean(stats_test.r,2); % average correlation values across channels

    % AshufP model
    stats = mTRFcrossval(AshufP(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,lambdas);
    % identify the optimal lambda parameter (same tie-safe selection)
    mn_r = squeeze(mean(mean(stats.r,3),1)); % average cross-validation r values over channels, then folds
    [~,best_idx] = max(mn_r);
    opt_lmb_AshufP(test_tr) = lambdas(best_idx);
    % fit the model with the optimal lambda
    mdl_AshufP = mTRFtrain(AshufP(train_trs),eeg.data(train_trs),downFs,1,tmin,tmax,opt_lmb_AshufP(test_tr));
    % test on the left out trial
    [~,stats_test] = mTRFpredict(AshufP{test_tr},eeg.data{test_tr},mdl_AshufP);
    r_AshufP(test_tr) = mean(stats_test.r,2); % average correlation values across channels
end

% Compare the acoustic (A) model with the no-pitch / shuffled-pitch models.
% Per-trial differences in prediction accuracy (positive = A is better)
dr_AnoP = r_A-r_AnoP;
dr_AshufP = r_A-r_AshufP;
% Significance of the pairwise differences (signed rank test)
p_anop = signrank(r_A,r_AnoP);
p_ashufp = signrank(r_A,r_AshufP);

figure
hold on
plot([-1 2],[0 0],'k--') % reference line at zero difference
% One dot per trial, one x position per comparison
plot(0:1,[dr_AnoP dr_AshufP],'b.','MarkerSize',16);
xlim([-1 2]);
set(gca,'XTick',[0 1],'XTickLabel',{'A-AnoP','A-AshufP'});
ylabel('\Deltar');
title(sprintf('Signed-rank test: p_{A-AnoP} = %.3f, p_{A-AshufP} = %.3f',p_anop,p_ashufp));