%   CNSP-Workshop 2021: Decoding Tutorial
% 
%   This tutorial loads 2 example datasets, a speech listening dataset
%   (LalorNatSpeech) and a cocktail party listening dataset 
%   (LalorCocktailParty), and demonstrates how to train neural decoders 
%   that can reconstruct stimulus features, such as the speech envelope, 
%   from multi-channel EEG data, and use them to decode selective attention 
%   for cocktail party listening experiments.
%
%   Dependencies:
%      CNSP utils: https://cnsp-workshop.github.io/website/resources.html
%      EEGLAB: https://sccn.ucsd.edu/eeglab/index.php
%      mTRF-Toolbox: https://github.com/mickcrosse/mTRF-Toolbox

%   References:
%      [1] Crosse MC, Zuk NJ, Di Liberto GM, Nidiffer A, Molholm S, Lalor 
%          EC (2021) Linear Modeling of Neurophysiological Responses to 
%          Naturalistic Stimuli: Methodological Considerations for Applied 
%          Research. PsyArXiv.
%      [2] Crosse MC, Di Liberto GM, Bednar A, Lalor EC (2016) The
%          multivariate temporal response function (mTRF) toolbox: a MATLAB
%          toolbox for relating neural signals to continuous stimuli. Front
%          Hum Neurosci 10:604.

%   CNSP-Workshop 2021
%   https://cnsp-workshop.github.io/website/index.html
%   Author: Mick Crosse <mickcrosse@gmail.com>
%   Copyright 2021 - Giovanni Di Liberto
%                    Nathaniel Zuk
%                    Michael Crosse
%                    Aaron Nidiffer
%                    (see license file for details)

%% 1. Data ingestion
% Resets the MATLAB session, moves to the workshop code folder, puts the
% tutorial and library folders on the path and loads the stimulus and EEG
% structures for the natural speech dataset (Subject 10).

close all;
clear; clc;

% a. Set main path
% NOTE: this location is specific to the original author's machine; edit
% it to point at your own copy of the workshop code before running.
main_path = 'C:\Users\mickc\Data\CNSP-workshop2021_code';
cd(main_path)

% b. Add other directories to path
% Folder names are joined with FULLFILE so that the file separator of the
% current platform is used (the original backslash-separated names only
% resolve on Windows).
tutorial_dir = 'CNSP_tutorial';
libs_dir = fullfile(tutorial_dir,'libs');
addpath(tutorial_dir)
addpath(fullfile(tutorial_dir,'CNSP-tutorial2'))
addpath(fullfile(libs_dir,'cnsp_utils'))
addpath(fullfile(libs_dir,'cnsp_utils','cnd'))
addpath(fullfile(libs_dir,'eeglab'))
addpath(fullfile(libs_dir,'mTRF-Toolbox_v2','mtrf'))

% c. Load data
% Only the variables named in each call ('stim', 'eeg') are read from the
% MAT-files.
disp('Loading data...')
cnd_dir = fullfile('.','datasets','LalorNatSpeech','dataCND');
load(fullfile(cnd_dir,'dataStim.mat'),'stim');
load(fullfile(cnd_dir,'dataSub10.mat'),'eeg');

%% 2. DRYAD data preparation for Subject 10

% Preprocessing pipeline applied to the loaded 'stim' and 'eeg' structures:
% feature selection, cropping, band-pass filtering, downsampling, bad
% channel interpolation, re-referencing and normalization.

% a. Use envelope feature
% Keep only the first row of the stimulus cell array (one cell per trial).
stim.data = stim.data(1,:);

% b. Crop EEG to match stim length
% Samples are indexed along the first dimension below, so the sample count
% is taken with size(...,1). length() would return the largest dimension,
% whichever one that happens to be.
% NOTE(review): assumes every trial is stored as time-by-channels.
for i = 1:numel(eeg.data)
    nsamp_stim = size(stim.data{i},1);
    nsamp_eeg = size(eeg.data{i},1);
    if nsamp_eeg < nsamp_stim
        % Fail with an explicit message instead of an opaque
        % index-out-of-bounds error from the cropping below.
        error('Trial %d: EEG has %d samples but the stimulus has %d.',...
            i,nsamp_eeg,nsamp_stim)
    end
    eeg.data{i} = eeg.data{i}(1:nsamp_stim,:);
    eeg.extChan{1,1}.data{i} = eeg.extChan{1,1}.data{i}(1:nsamp_stim,:);
end

% c. Set up highpass filter
% Filter is designed for the current (pre-downsampling) EEG sample rate.
highpass_cutoff = 0.1; % presumably in Hz - confirm against getHPFilt
highpass_order = 3;
hd_hpf = getHPFilt(eeg.fs,highpass_cutoff,highpass_order);

% d. Set up lowpass filter
lowpass_cutoff = 8; % presumably in Hz - confirm against getLPFilt
lowpass_order = 3;
hd_lpf = getLPFilt(eeg.fs,lowpass_cutoff,lowpass_order);

% e. Filter EEG recording channels
% Each trial is passed through the highpass filter and then the lowpass
% filter.
disp('Filtering recording channels...')
eeg.data = cellfun(@(x) filtfilthd(hd_hpf,x),eeg.data,'UniformOutput',false);
eeg.data = cellfun(@(x) filtfilthd(hd_lpf,x),eeg.data,'UniformOutput',false);

% f. Filter EEG external channels
% Same two filters, applied to the first external-channel set.
disp('Filtering external channels...')
eeg.extChan{1,1}.data = cellfun(@(x) filtfilthd(hd_hpf,x),eeg.extChan{1,1}.data,'UniformOutput',false);
eeg.extChan{1,1}.data = cellfun(@(x) filtfilthd(hd_lpf,x),eeg.extChan{1,1}.data,'UniformOutput',false);

% g. Downsample data
% EEG and stimulus are brought to the same target rate.
% NOTE(review): later sections read eeg.fs as the sample rate, so
% cndDownsample is presumed to update the .fs field - confirm.
fs_new = 64;
disp('Downsampling data...')
eeg = cndDownsample(eeg,fs_new);
stim = cndDownsample(stim,fs_new);

% h. Interpolate bad channels
% Skipped when the EEG structure carries no channel location information.
disp('Interpolating bad channels...')
if isfield(eeg,'chanlocs')
    for i = 1:numel(eeg.data)
        eeg.data{i} = removeBadChannels(eeg.data{i},eeg.chanlocs);
    end
end

% i. Re-reference EEG data
disp('Re-referencing EEG data...')
eeg = cndReref(eeg,'Avg');

% j. Normalize EEG data
% A single standard deviation is computed over all samples of all trials
% and channels, and every trial is divided by it.
disp('Normalizing data...')
eeg_data_mat = cell2mat(eeg.data');
eeg_std = std(eeg_data_mat(:));
eeg.data = cellfun(@(x) x/eeg_std,eeg.data,'UniformOutput',false);

%% 3. Cross-validation
% Splits the trials into training and test sets, sweeps the regularization
% parameter with mTRFcrossval on the training set and plots the resulting
% correlation and error curves.

% a. Define training and test sets
% The held-out trials are copied out first; the training sets are the full
% trial lists with those same trials deleted.
test_trials = 10:13; % 20% of data
stim_test = stim.data(test_trials);
eeg_test = eeg.data(test_trials);
stim_train = stim.data;
stim_train(test_trials) = [];
eeg_train = eeg.data;
eeg_train(test_trials) = [];

% b. Model hyperparameters
Dir = -1;               % direction flag passed to the mTRF functions
tmin = 0;               % minimum time lag (ms)
tmax = 250;             % maximum time lag (ms)
lambda_exp = -2:2:8;    % exponents of the regularization values
lambda_vals = 10.^lambda_exp;
nlambda = numel(lambda_vals);

% c. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda_vals,...
    'zeropad',0,'fast',1);

% Error bars: standard deviation divided by the square root of the number
% of training trials.
ntrain = numel(stim_train);
lambda_pos = 1:nlambda;

% d. Plot CV accuracy
figure(1)
subplot(2,2,1)
r_avg = mean(cv.r);
r_bar = std(cv.r)/sqrt(ntrain);
errorbar(lambda_pos,r_avg,r_bar,'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square
grid on

% e. Plot CV error
subplot(2,2,2)
err_avg = mean(cv.err);
err_bar = std(cv.err)/sqrt(ntrain);
errorbar(lambda_pos,err_avg,err_bar,'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square
grid on

%% 4. Model training
% Picks the regularization value with the highest mean cross-validated
% correlation, trains the model with it on the training set and shows the
% weights at four selected lags as scalp maps.

% a. Get optimal hyperparameters
mean_r = mean(cv.r);
[rmax,idx] = max(mean_r);
lambda = lambda_vals(idx);

% b. Train model
disp('Training model...')
model = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,...
    'zeropad',0);

% c. Plot decoder weights
% A shared, symmetric colour limit is taken from the largest absolute
% weight across lag indices 7 to 14, so the four maps are comparable.
w_range = abs(model.w(:,7:14));
lim = max(w_range(:));
lag_idx = [7,9,11,14];
figure(2)
for k = 1:numel(lag_idx)
    subplot(2,2,k)
    topoplot(model.w(:,lag_idx(k)),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
    title([num2str(model.t(lag_idx(k))),' ms'])
end

%% 5. Model testing
% Applies the trained model to the held-out trials and adds the
% reconstruction and a validation-vs-test comparison to figure 1.

% a. Test model
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,model,'zeropad',0);

% b. Plot reconstruction
% First test trial only; sample indices are converted to seconds with the
% EEG sample rate.
t_orig = (1:length(stim_test{1}))/eeg.fs;
t_pred = (1:length(pred{1}))/eeg.fs;
figure(1)
subplot(2,2,3)
plot(t_orig,stim_test{1},'linewidth',2)
hold on
plot(t_pred,pred{1},'linewidth',2)
hold off
xlim([0,10])
title('Reconstruction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square
grid on
legend('Orig','Pred')

% c. Plot test correlation
% Bar 1 is the best mean cross-validation correlation (rmax); bar 2 is the
% mean correlation over the test trials.
subplot(2,2,4)
bar(1,rmax)
hold on
bar(2,mean(test.r))
hold off
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square
grid on

%% 6. Single-lag stimulus reconstruction
% Cross-validates one model per time lag over a wider lag window and plots
% the mean correlation and error per lag with a shaded variability band.

% Run single-lag cross-validation
tmin = -250;
tmax = 500;
[stats,t] = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,...
    'type','single','zeropad',0);

% Compute mean and variance
macc = squeeze(mean(stats.r))';
vacc = squeeze(var(stats.r))';
merr = squeeze(mean(stats.err))';
verr = squeeze(var(stats.err))';

% Compute variance bound
% The band is mean -/+ sqrt(variance/number of training trials). Its
% outline is traced as one closed polygon: lower edge over the negated,
% reversed lag axis, then upper edge over the negated lag axis.
num_folds = numel(stim_train);
acc_halfwidth = sqrt(vacc/num_folds);
err_halfwidth = sqrt(verr/num_folds);
lag_axis = -fliplr(t);
xacc = [lag_axis,-t];
yacc = [fliplr(macc-acc_halfwidth),macc+acc_halfwidth];
xerr = [lag_axis,-t];
yerr = [fliplr(merr-err_halfwidth),merr+err_halfwidth];

% Plot accuracy
figure(3)
subplot(1,2,1)
h = fill(xacc,yacc,'b','edgecolor','none');
hold on
set(h,'facealpha',0.2)
xlim([tmin,tmax])
axis square
grid on
plot(lag_axis,fliplr(macc),'linewidth',2)
hold off
title('Reconstruction Accuracy')
xlabel('Time lag (ms)')
ylabel('Correlation')

% Plot error
subplot(1,2,2)
h = fill(xerr,yerr,'b','edgecolor','none');
hold on
set(h,'facealpha',0.2)
xlim([tmin,tmax])
axis square
grid on
plot(lag_axis,fliplr(merr),'linewidth',2)
hold off
title('Reconstruction Error')
xlabel('Time lag (ms)')
ylabel('MSE')

%% 7. Model performance with less data
% Repeats cross-validation, training and testing with only the first three
% trials as training data and summarizes the outcome in figure 4.

% a. Define training and test sets
% Trials 1-3 are used for training; trials 10-13 are held out for testing.
train_trials = 1:3;
test_trials = 10:13;
stim_train = stim.data(train_trials);
eeg_train = eeg.data(train_trials);
stim_test = stim.data(test_trials);
eeg_test = eeg.data(test_trials);

% b. Model hyperparameters
Dir = -1;               % direction flag passed to the mTRF functions
tmin = 0;               % minimum time lag (ms)
tmax = 250;             % maximum time lag (ms)
lambda_exp = -2:2:8;    % exponents of the regularization values
lambda_vals = 10.^lambda_exp;
nlambda = numel(lambda_vals);

% c. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda_vals,...
    'zeropad',0,'fast',1);

% d. Get optimal hyperparameters
% The value with the highest mean cross-validated correlation is selected.
mean_r = mean(cv.r);
[rmax,idx] = max(mean_r);
lambda = lambda_vals(idx);

% e. Train model
disp('Training model...')
model = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,'zeropad',0);

% f. Test model
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,model,'zeropad',0);

% Error bars: standard deviation divided by the square root of the number
% of training trials.
ntrain = numel(stim_train);
lambda_pos = 1:nlambda;

% g. Plot CV accuracy
figure(4)
subplot(2,2,1)
r_bar = std(cv.r)/sqrt(ntrain);
errorbar(lambda_pos,mean_r,r_bar,'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square
grid on

% h. Plot CV error
subplot(2,2,2)
err_avg = mean(cv.err);
err_bar = std(cv.err)/sqrt(ntrain);
errorbar(lambda_pos,err_avg,err_bar,'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square
grid on

% i. Plot reconstruction
% First test trial only; sample indices are converted to seconds with the
% EEG sample rate.
t_orig = (1:length(stim_test{1}))/eeg.fs;
t_pred = (1:length(pred{1}))/eeg.fs;
subplot(2,2,3)
plot(t_orig,stim_test{1},'linewidth',2)
hold on
plot(t_pred,pred{1},'linewidth',2)
hold off
xlim([0,10])
title('Reconstruction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square
grid on
legend('Orig','Pred')

% j. Plot test correlation
% Bar 1 is the best mean cross-validation correlation; bar 2 is the mean
% correlation over the test trials.
subplot(2,2,4)
bar(1,rmax)
hold on
bar(2,mean(test.r))
hold off
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square
grid on

%% 8. DRYAD data preparation for Subject 13
% Stores the EEG structure currently in the workspace (Subject 10) in
% data{1}, then loads and preprocesses Subject 13 with the same pipeline
% and stores the result in data{2}.

data = cell(1,2);
data{1} = eeg;
clear eeg

% a. Load data
% Paths are built with FULLFILE so the platform's file separator is used
% (the original backslash-separated paths only resolve on Windows).
disp('Loading data...')
cnd_dir = fullfile('.','datasets','LalorNatSpeech','dataCND');
load(fullfile(cnd_dir,'dataStim.mat'),'stim');
load(fullfile(cnd_dir,'dataSub13.mat'),'eeg');

% b. Use envelope feature
% Keep only the first row of the stimulus cell array (one cell per trial).
stim.data = stim.data(1,:);

% c. Crop EEG to match stim length
% Samples are indexed along the first dimension below, so the sample count
% is taken with size(...,1). length() would return the largest dimension,
% whichever one that happens to be.
% NOTE(review): assumes every trial is stored as time-by-channels.
for i = 1:numel(eeg.data)
    nsamp_stim = size(stim.data{i},1);
    nsamp_eeg = size(eeg.data{i},1);
    if nsamp_eeg < nsamp_stim
        % Fail with an explicit message instead of an opaque
        % index-out-of-bounds error from the cropping below.
        error('Trial %d: EEG has %d samples but the stimulus has %d.',...
            i,nsamp_eeg,nsamp_stim)
    end
    eeg.data{i} = eeg.data{i}(1:nsamp_stim,:);
    eeg.extChan{1,1}.data{i} = eeg.extChan{1,1}.data{i}(1:nsamp_stim,:);
end

% d. Set up highpass filter
% Filter is designed for the current (pre-downsampling) EEG sample rate.
highpass_cutoff = 0.1; % presumably in Hz - confirm against getHPFilt
highpass_order = 3;
hd_hpf = getHPFilt(eeg.fs,highpass_cutoff,highpass_order);

% e. Set up lowpass filter
lowpass_cutoff = 8; % presumably in Hz - confirm against getLPFilt
lowpass_order = 3;
hd_lpf = getLPFilt(eeg.fs,lowpass_cutoff,lowpass_order);

% f. Filter EEG recording channels
% Each trial is passed through the highpass filter and then the lowpass
% filter.
disp('Filtering recording channels...')
eeg.data = cellfun(@(x) filtfilthd(hd_hpf,x),eeg.data,'UniformOutput',false);
eeg.data = cellfun(@(x) filtfilthd(hd_lpf,x),eeg.data,'UniformOutput',false);

% g. Filter EEG external channels
% Same two filters, applied to the first external-channel set.
disp('Filtering external channels...')
eeg.extChan{1,1}.data = cellfun(@(x) filtfilthd(hd_hpf,x),eeg.extChan{1,1}.data,'UniformOutput',false);
eeg.extChan{1,1}.data = cellfun(@(x) filtfilthd(hd_lpf,x),eeg.extChan{1,1}.data,'UniformOutput',false);

% h. Downsample data
% EEG and stimulus are brought to the same target rate.
fs_new = 64;
disp('Downsampling data...')
eeg = cndDownsample(eeg,fs_new);
stim = cndDownsample(stim,fs_new);

% i. Interpolate bad channels
% Skipped when the EEG structure carries no channel location information.
disp('Interpolating bad channels...')
if isfield(eeg,'chanlocs')
    for i = 1:numel(eeg.data)
        eeg.data{i} = removeBadChannels(eeg.data{i},eeg.chanlocs);
    end
end

% j. Re-reference EEG data
disp('Re-referencing EEG data...')
eeg = cndReref(eeg,'Avg');

% k. Normalize EEG data
% A single standard deviation is computed over all samples of all trials
% and channels, and every trial is divided by it.
disp('Normalizing data...')
eeg_data_mat = cell2mat(eeg.data');
eeg_std = std(eeg_data_mat(:));
eeg.data = cellfun(@(x) x/eeg_std,eeg.data,'UniformOutput',false);

data{2} = eeg;

%% 9. Subject-independent model
% Trains one model on trials 1-3 pooled from both prepared subjects
% (data{1} and data{2}), tests it on trials 10-13 of data{1} and
% summarizes the outcome in figure 5.

% a. Define training and test sets
% The stimulus trials are listed twice so that each EEG training trial
% (three from each subject) is paired with its stimulus.
train_trials = 1:3;
test_trials = 10:13;
ntrain_per_sub = numel(train_trials);
stim_train = stim.data(train_trials);
stim_train(ntrain_per_sub+train_trials) = stim.data(train_trials);
eeg_train = data{1,1}.data(train_trials);
eeg_train(ntrain_per_sub+train_trials) = data{1,2}.data(train_trials);
stim_test = stim.data(test_trials);
eeg_test = data{1,1}.data(test_trials);

% b. Model hyperparameters
Dir = -1;               % direction flag passed to the mTRF functions
tmin = 0;               % minimum time lag (ms)
tmax = 250;             % maximum time lag (ms)
lambda_exp = -2:2:8;    % exponents of the regularization values
lambda_vals = 10.^lambda_exp;
nlambda = numel(lambda_vals);

% c. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda_vals,...
    'zeropad',0,'fast',1);

% d. Get optimal hyperparameters
% The value with the highest mean cross-validated correlation is selected.
mean_r = mean(cv.r);
[rmax,idx] = max(mean_r);
lambda = lambda_vals(idx);

% e. Train model
disp('Training model...')
model = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,'zeropad',0);

% f. Test model
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,model,'zeropad',0);

% Error bars: standard deviation divided by the square root of the number
% of training trials.
ntrain = numel(stim_train);
lambda_pos = 1:nlambda;

% g. Plot CV accuracy
figure(5)
subplot(2,2,1)
r_bar = std(cv.r)/sqrt(ntrain);
errorbar(lambda_pos,mean_r,r_bar,'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square
grid on

% h. Plot CV error
subplot(2,2,2)
err_avg = mean(cv.err);
err_bar = std(cv.err)/sqrt(ntrain);
errorbar(lambda_pos,err_avg,err_bar,'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square
grid on

% i. Plot reconstruction
% First test trial only; sample indices are converted to seconds with the
% EEG sample rate.
t_orig = (1:length(stim_test{1}))/eeg.fs;
t_pred = (1:length(pred{1}))/eeg.fs;
subplot(2,2,3)
plot(t_orig,stim_test{1},'linewidth',2)
hold on
plot(t_pred,pred{1},'linewidth',2)
hold off
xlim([0,10])
title('Reconstruction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square
grid on
legend('Orig','Pred')

% j. Plot test correlation
% Bar 1 is the best mean cross-validation correlation; bar 2 is the mean
% correlation over the test trials.
subplot(2,2,4)
bar(1,rmax)
hold on
bar(2,mean(test.r))
hold off
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square
grid on

%% 10. Attention dataset
% Clears the workspace and loads the two stimulus structures and one
% subject's EEG structure from the cocktail party dataset.

clear; clc;

% a. Load data
% Paths are built with FULLFILE so the platform's file separator is used
% (the original backslash-separated paths only resolve on Windows).
disp('Loading data...')
cnd_dir = fullfile('.','datasets','LalorCocktailParty','dataCND');
load(fullfile(cnd_dir,'dataStim.mat'),'stim','stim2');
load(fullfile(cnd_dir,'dataSub5.mat'),'eeg');

%% 11. DRYAD data preparation for Subject 5

% Preprocessing pipeline applied to the loaded 'stim', 'stim2' and 'eeg'
% structures: feature selection, cropping, band-pass filtering,
% downsampling, bad channel interpolation, re-referencing and
% normalization.

% a. Use envelope feature
% Keep only the first row of each stimulus cell array (one cell per trial).
stim.data = stim.data(1,:);
stim2.data = stim2.data(1,:);

% b. Crop EEG and both stimuli to a common length
% Every trial is trimmed to the shortest of the three signals, so that
% stim, stim2 and the EEG always end up with the same number of samples
% (previously the length of stim2 was never checked). Samples are indexed
% along the first dimension, hence size(...,1) rather than length().
for i = 1:numel(eeg.data)
    nsamp = min([size(stim.data{i},1),size(stim2.data{i},1),size(eeg.data{i},1)]);
    stim.data{i} = stim.data{i}(1:nsamp,:);
    stim2.data{i} = stim2.data{i}(1:nsamp,:);
    eeg.data{i} = eeg.data{i}(1:nsamp,:);
    % External channels are only trimmed when they are longer than the
    % common length.
    if size(eeg.extChan{1,1}.data{i},1) > nsamp
        eeg.extChan{1,1}.data{i} = eeg.extChan{1,1}.data{i}(1:nsamp,:);
    end
end

% c. Set up highpass filter
% Filter is designed for the current (pre-downsampling) EEG sample rate.
highpass_cutoff = 0.1; % presumably in Hz - confirm against getHPFilt
highpass_order = 3;
hd_hpf = getHPFilt(eeg.fs,highpass_cutoff,highpass_order);

% d. Set up lowpass filter
lowpass_cutoff = 8; % presumably in Hz - confirm against getLPFilt
lowpass_order = 3;
hd_lpf = getLPFilt(eeg.fs,lowpass_cutoff,lowpass_order);

% e. Filter EEG recording channels
% Each trial is passed through the highpass filter and then the lowpass
% filter.
disp('Filtering recording channels...')
eeg.data = cellfun(@(x) filtfilthd(hd_hpf,x),eeg.data,'UniformOutput',false);
eeg.data = cellfun(@(x) filtfilthd(hd_lpf,x),eeg.data,'UniformOutput',false);

% f. Filter EEG external channels
% Same two filters, applied to the first external-channel set.
disp('Filtering external channels...')
eeg.extChan{1,1}.data = cellfun(@(x) filtfilthd(hd_hpf,x),eeg.extChan{1,1}.data,'UniformOutput',false);
eeg.extChan{1,1}.data = cellfun(@(x) filtfilthd(hd_lpf,x),eeg.extChan{1,1}.data,'UniformOutput',false);

% g. Downsample data
% EEG and both stimuli are brought to the same target rate.
fs_new = 64;
disp('Downsampling data...')
eeg = cndDownsample(eeg,fs_new);
stim = cndDownsample(stim,fs_new);
stim2 = cndDownsample(stim2,fs_new);

% h. Interpolate bad channels
% Skipped when the EEG structure carries no channel location information.
disp('Interpolating bad channels...')
if isfield(eeg,'chanlocs')
    for i = 1:numel(eeg.data)
        eeg.data{i} = removeBadChannels(eeg.data{i},eeg.chanlocs);
    end
end

% i. Re-reference EEG data
disp('Re-referencing EEG data...')
eeg = cndReref(eeg,'Avg');

% j. Normalize EEG data
% A single standard deviation is computed over all samples of all trials
% and channels, and every trial is divided by it.
disp('Normalizing data...')
eeg_data_mat = cell2mat(eeg.data');
eeg_std = std(eeg_data_mat(:));
eeg.data = cellfun(@(x) x/eeg_std,eeg.data,'UniformOutput',false);

%% 12. Decode attention
% Cross-validates over all trials with mTRFattncrossval, given both
% stimulus streams, and plots reconstruction accuracy/error for each
% stream together with the decoding accuracy and modulation index
% (figure 6).

% a. Define training and test sets
% All trials are handed to the cross-validation routine.
stim_train = stim.data;
stim2_train = stim2.data;
eeg_train = eeg.data;

% b. Model hyperparameters
Dir = -1;               % direction flag passed to the mTRF functions
tmin = 0;               % minimum time lag (ms)
tmax = 250;             % maximum time lag (ms)
lambda_exp = -2:2:8;    % exponents of the regularization values
lambda_vals = 10.^lambda_exp;
nlambda = numel(lambda_vals);

% c. Run fast cross-validation
disp('Running cross-validation...')
[cv,cv1,cv2] = mTRFattncrossval(stim_train,stim2_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda_vals,...
    'zeropad',0,'fast',1);

% Error bars: standard deviation divided by the square root of the number
% of trials.
ntrial = numel(stim_train);
lambda_pos = 1:nlambda;

% d. Plot reconstruction accuracy
% cv1 is plotted first and cv2 second; the legend labels them 'Att.' and
% 'Unatt.' respectively.
figure(6)
subplot(2,2,1)
hold on
errorbar(lambda_pos,mean(cv1.r),std(cv1.r)/sqrt(ntrial),'linewidth',2)
errorbar(lambda_pos,mean(cv2.r),std(cv2.r)/sqrt(ntrial),'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('Reconstruction Acc.')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square
grid on
box on
l = legend('Att.','Unatt.','location','west');
set(l,'box','off')

% e. Plot reconstruction error
subplot(2,2,2)
hold on
errorbar(lambda_pos,mean(cv1.err),std(cv1.err)/sqrt(ntrial),'linewidth',2)
errorbar(lambda_pos,mean(cv2.err),std(cv2.err)/sqrt(ntrial),'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('Reconstruction Err.')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square
grid on
box on

% f. Decoding accuracy
subplot(2,2,3)
plot(lambda_pos,cv.acc,'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('Decoding Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Accuracy')
axis square
grid on

% g. Plot modulation index
subplot(2,2,4)
plot(lambda_pos,cv.d,'linewidth',2)
set(gca,'xtick',lambda_pos,'xticklabel',lambda_exp)
xlim([0,nlambda+1])
title('Modulation Index')
xlabel('Regularization (1\times10^\lambda)')
ylabel('\itd''')
axis square
grid on
