% CNSP Tutorial - Stimulus feature extraction
clear; clc;

% Locate the data folder relative to the current working directory.
% fullfile/filesep are used instead of hard-coded '\' so the script also
% runs on Linux/macOS. The trailing separator is kept on purpose because
% the rest of the script builds file names by plain concatenation with
% folderCND.
folderRoot = pwd;
folderCND = [fullfile(folderRoot,'dataCND') filesep];


% Some parameters
lambda = 10.^(0:0.5:4);   % lambda values handed to every mTRFcrossval call below
fs = 64;                  % target sampling rate the EEG is resampled to
rec_dur = 180;            % not referenced anywhere else in this script
tmin = 0;                 % start of the lag window passed to the mTRF functions
tmax = 300;               % end of the lag window passed to the mTRF functions
subs = 1:19;              % subject numbers to process

% Load stimulus features (only the variable this script uses)
load([folderCND 'dataStim.mat'],'dataStim')

% extract single features: one row of dataStim per feature, one column
% per trial
stimE = dataStim(1,:);    % E = envelope
stimS = dataStim(2,:);    % S = spectrogram
stimF = dataStim(3,:);    % F = phonetic features

% build multi feature: spectrogram columns first, phonetic-feature columns
% after them. The trial count is taken from the data instead of being
% hard-coded.
nTrials = size(dataStim,2);
stimSF = cell(1,nTrials);
for tt = 1:nTrials
    stimSF{1,tt} = [stimS{1,tt} stimF{1,tt}];
end

% Per-subject TRF analysis. For every subject the preprocessed EEG is
% loaded, resampled to fs, trimmed to the stimulus length and z-scored;
% a series of models is then cross-validated and everything is collected
% in the struct 'trf', which is saved to one .mat file per subject.
%
% Each "I*" variable is the index into lambda that maximises the
% prediction correlation r after averaging over dimensions 1 and 3 of r.

% Loop-invariant sizes, derived from the data instead of hard-coded.
nTrials = size(stimS,2);      % number of trials (columns of the stimulus cells)
nS      = size(stimS{1},2);   % number of spectrogram columns; they come first in stimSF
nShuf   = 50;                 % number of repetitions of the full-feature shuffle

for sub = subs
    % Start every subject from a clean slate so that nothing computed for
    % the previous subject can leak into this subject's results.
    clear trf SF_shuf S_shufF Is_shuff Isf_shuf S_indv Is_indv S_ F_
    
    %% Load EEG data
    
    % get preprocessed data
    load([folderCND 'pre_dataSub' num2str(sub) '.mat'],'eeg')
    
    % resample, trim excess, zscore
    for tt = 1:length(eeg.data)
        eeg.data{tt} = resample(eeg.data{tt},fs,eeg.fs);
        % keep only as many samples as the stimulus of the same trial has
        eeg.data{tt} = eeg.data{tt}(1:size(dataStim{1,tt},1),:);
        eeg.data{tt} = zscore(eeg.data{tt});
    end
    eeg.fs = fs;
    
    %% Single-feature models
    
    % E = envelope
    E = mTRFcrossval(stimE,eeg.data,eeg.fs,1,tmin,tmax,lambda);
    [~,Ie] = max(mean(mean(E.r,1),3));
    
    % S = spectrogram
    S = mTRFcrossval(stimS,eeg.data,eeg.fs,1,tmin,tmax,lambda);
    [~,Is] = max(mean(mean(S.r,1),3));
    
    % F = phonetic features
    F = mTRFcrossval(stimF,eeg.data,eeg.fs,1,tmin,tmax,lambda);
    [~,If] = max(mean(mean(F.r,1),3));
    
    trf.E = E;
    trf.S = S;
    trf.F = F;
    trf.Ie = Ie;
    trf.Is = Is;
    trf.If = If;
    
    %% Multi-feature model
    
    % SF = spectrogram + phonetic features
    SF = mTRFcrossval(stimSF,eeg.data,eeg.fs,1,tmin,tmax,lambda);
    [~,Isf] = max(mean(mean(SF.r,1),3));
    
    % SFO = spectrogram + phonetic features + word onsets
    %SFO = mTRFcrossval(stimSFO,eeg.data,eeg.fs,1,tmin,tmax,lambda);
    %[avgE,Ie] = max(mean(mean(E.r,1),3));
    
    trf.SF = SF;
    trf.Isf = Isf;
    
    %%% SIMPLE TOPOS
    openfig('figs/SimpleTopos.fig')
    
    %%% OG FIGURE
    openfig('figs/OriginalGio.fig')
    
    %% Shuffling full feature
    
    % Difference of shuffling stimulus before or model after mTRFtrain
    
    % S+F_shuf vs S+F
    for ii = 1:nShuf
        stimSF_shuf = stimSF;   % S kept, F taken from a permuted trial
        stimS_shufF = stimSF;   % F kept, S taken from a permuted trial
        
        tr_shuf = randperm(nTrials);
        
        for tt = 1:nTrials
            % NOTE: the replacement columns must have the same number of
            % rows as the trial they are written into.
            stimSF_shuf{tt}(:,nS+1:end) = stimSF{tr_shuf(tt)}(:,nS+1:end);
            stimS_shufF{tt}(:,1:nS) = stimSF{tr_shuf(tt)}(:,1:nS);
        end
        
        SF_shuf(ii) = mTRFcrossval(stimSF_shuf,eeg.data,eeg.fs,1,tmin,tmax,lambda);
        S_shufF(ii) = mTRFcrossval(stimS_shufF,eeg.data,eeg.fs,1,tmin,tmax,lambda);
        
        % Best-lambda index of each model, stored under the name that
        % matches the model (SF_shuf -> Isf_shuf, S_shufF -> Is_shuff),
        % consistent with the S -> Is / SF -> Isf convention used above.
        [~,Isf_shuf(ii)] = max(mean(mean(SF_shuf(ii).r,1),3));
        [~,Is_shuff(ii)] = max(mean(mean(S_shufF(ii).r,1),3));
    end
    
    
    trf.SF_shuf = SF_shuf;
    trf.S_shufF = S_shufF;
    trf.Is_shuff = Is_shuff;
    trf.Isf_shuf = Isf_shuf;
    
    %%% FULL FEATURE SHUFFLE
    openfig('figs/FullFeatShuf.fig')
    
    %% Shuffling individual feature
    
    % For every spectrogram column ii, replace that single column in each
    % trial with the same column from a randomly chosen other trial, then
    % cross-validate the spectrogram model on the modified stimulus.
    for ii = 1:nS
        stimS_shuf = stimS;
        tr_shuf = randperm(nTrials);
        for tt = 1:nTrials   % colon, so that every trial is shuffled
            stimS_shuf{tt}(:,ii) = stimS{tr_shuf(tt)}(:,ii);
        end
        
        S_indv(ii) = mTRFcrossval(stimS_shuf,eeg.data,eeg.fs,1,tmin,tmax,lambda);
        [~,Is_indv(ii)] = max(mean(mean(S_indv(ii).r,1),3));
    end
    
    
    trf.S_indv = S_indv;
    trf.Is_indv = Is_indv;
    
    %%% INDV FEATURE SHUFFLE
    openfig('figs/IndvFeatShuf.fig')
    
    %% Dealing with correlated features, partial regression
    
    % NOTE(review): mTRFpartial is not defined in this file; presumably it
    % returns the EEG with the contribution of the given feature removed
    % - confirm against its implementation.
    resp_noS = mTRFpartial(stimS,eeg.data,eeg.fs,1,tmin,tmax,lambda);
    F_noS = mTRFcrossval(stimF,resp_noS,eeg.fs,1,tmin,tmax,lambda);
    [~,Ifns] = max(mean(mean(F_noS.r,1),3));
    
    resp_noF = mTRFpartial(stimF,eeg.data,eeg.fs,1,tmin,tmax,lambda);
    S_noF = mTRFcrossval(stimS,resp_noF,eeg.fs,1,tmin,tmax,lambda);
    [~,Isnf] = max(mean(mean(S_noF.r,1),3));
    
    trf.F_noS = F_noS;
    trf.Ifns = Ifns;
    
    trf.S_noF = S_noF;
    trf.Isnf = Isnf;
    
    
    %%% FULL MINUS PARTIAL
    openfig('figs/FullMinusPartial.fig')
    
    %% Banded regression see Nunez-Elizalde et al 2018
    % for when your stimulus features are so different and you need to
    % account for that.
    % grouping labels every column of stimSF: 1 = spectrogram, 2 = phonetic
    grouping = [ones(1,size(stimS{1},2)) 2*ones(1,size(stimF{1},2))];
    [SFb,l_min]=mTRFcvsearch(stimSF,eeg.data,eeg.fs,1,tmin,tmax,grouping);
    [~,Isfb] = max(mean(mean(SFb.r,1),3));   % computed but not stored in trf
    
    trf.SFb = SFb;
    trf.l_min = l_min;
    
    %%% BANDED LAMBDA FIGURE
    openfig('figs/BandedLambda.fig')
    openfig('figs/BandedTopo.fig')
    
    %% Dealing with correlated features, full and reduce modeling, after Nunez-Elizalde et al 2018
    
    for tt = 1:nTrials
        
        % all trials except tt
        tr_trials = setdiff(1:nTrials,tt);
        
        % First we'll xval a model using a subset of trials and all the features we want to account for
        % Currently this won't work because banded ridge isn't finished.
        SF_ = mTRFcrossval(stimSF(1,tr_trials),eeg.data(1,tr_trials),eeg.fs,1,tmin,tmax,lambda);
        [~,I] = max(mean(mean(SF_.r,1),3));
        
        % Fit the full model on the remaining trial
        % NOTE(review): the model is trained on trial tt and, below, also
        % evaluated on trial tt - confirm that this is intended rather
        % than training on tr_trials and testing on tt.
        modelSF = mTRFtrain(stimSF(1,tt),eeg.data(1,tt),eeg.fs,1,tmin,tmax,lambda(I));
        
        % Extract the features of interest: keep only the weights of the
        % spectrogram columns (first nS) or of the phonetic columns (rest)
        modelS_ = modelSF; modelS_.w(nS+1:end,:,:) = [];
        modelF_ = modelSF; modelF_.w(1:nS,:,:) = [];
        
        % Predict remaining trial using remaining feature
        [~,temp] = mTRFpredict(stimS(1,tt),eeg.data(1,tt),modelS_);
        S_.r(tt,:,:) = temp.r; S_.err(tt,:,:) = temp.err;
        [~,temp] = mTRFpredict(stimF(1,tt),eeg.data(1,tt),modelF_);
        F_.r(tt,:,:) = temp.r; F_.err(tt,:,:) = temp.err;
    end
    
    [~,Is_] = max(mean(mean(S_.r,1),3));
    [~,If_] = max(mean(mean(F_.r,1),3));
    
    
    trf.S_ = S_;
    trf.Is_ = Is_;
    
    trf.F_ = F_;
    trf.If_ = If_;
    
    %%% NO FIGURE CAUSE BANDED RIDGE NOT FULLY FUNCTIONAL
    
    %% Save Subject
    
    % one file per subject; the lag window is encoded in the file name
    save([folderCND 'trf_dataSub',num2str(sub),'_',num2str(tmin),'-',num2str(tmax),'.mat'],'trf')
    
    
end
