SVM classifier on paired data using CANlab tools

Support Vector Machines (SVMs) are a fast and efficient way to identify multivariate patterns that separate two conditions. For example, we might want to identify brain patterns that predict (or 'decode') which of two conditions a person is experiencing, such as pain vs. cognitive control (the comparison we'll use in this lab).
SVMs are called "maximum margin classifiers" because they identify a hyperplane (think of a 2-D boundary like a curtain, but in multidimensional space) that maximizes the separation (or "margin", in the sense that there is a buffer of empty space) between images belonging to the two classes. The outcome (or response) data is often labeled with values of 1 and -1, i.e., "on"/"off" labels for the classes.
SVMs can be run on any set of images using whole-brain patterns, multiple regions of interest, or other derived features (input variables). The SVM will try to define a hyperplane (boundary) that separates the classes of all the images (or other measures) fed into it. It delivers two outputs: (1) a map of weights over the features (e.g., voxels) plus an intercept, and (2) a score for each image, its signed distance from the hyperplane.
The weight map plus the intercept is the "model".
The scores can be evaluated for accuracy using cross-validation or by applying the model to an independent test dataset. Applying the model to a dataset with independent errors is the only way to prevent overfitting and provide a valid estimate of the model's classification performance.
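Concretely, the score for an image is its projection onto the weight map plus the intercept, and the predicted class is the sign of that score. Here is a minimal sketch with hypothetical variables (X: an images x voxels data matrix; w: the voxel weights; b: the intercept):
% Hypothetical trained model: w = [n_voxels x 1] weight vector, b = scalar intercept
% X = [n_images x n_voxels] data matrix of images to classify
scores = X * w + b;               % one signed score per image (distance from the hyperplane, up to scaling)
predicted_class = sign(scores);   % +1 = first class ("on"), -1 = second class ("off")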
The comparisons between classes can either be between-person, where the "on" and "off" images come from different participants, or within-person (paired), where there is at least one example of an "on" and an "off" condition for each participant. The SVM classifier will simply try to separate the classes based on the features fed into it, regardless of the identity of the participant. But when it delivers scores back, we can keep track of which scores came from the same person and compute a within-person, forced-choice accuracy. This assesses whether, given two images from the same participant, we can guess which one was collected during the "on" condition and which during the "off" condition. This is usually quite a bit higher than the single-interval accuracy, which makes a guess for each image separately based on its absolute SVM score. Many individual differences can influence the overall score (e.g., vasculature, overall BOLD magnitude, variations in brain morphology and normalization) and lead to poor single-interval accuracy, but many of these sources of variation drop out when we compare paired images from the same person.
In this lab, we'll use the function canlab_run_paired_SVM to train on a mix of images, but calculate the forced-choice accuracy within a person.
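Before turning to the data, here is a minimal sketch of the forced-choice comparison described above, assuming hypothetical vectors scores_on and scores_off holding each participant's cross-validated SVM scores for the two conditions (same participant order in both):
% Forced-choice accuracy: proportion of participants whose "on" image scores
% higher than their own "off" image
fc_accuracy = mean(scores_on > scores_off);
% Single-interval accuracy, for comparison: each image is classified on its own,
% relative to the decision boundary at zero
si_accuracy = mean([scores_on > 0; scores_off < 0]);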

Set up directory and load files

dat = load_image_set('Kragel18_alldata');
Loading /Users/torwager/Dropbox (Dartmouth College)/COURSES/Courses_Dartmouth/2021_3_Spring_fMRI_Class/PSYC60_Shared_resources_for_students/datasets/kragel_2018_nat_neurosci_270_subjects_test_images.mat
Loaded images (15 per study): ThermalPain1, ThermalPain2, VisceralPain1, VisceralPain2, MechanicalPain1, MechanicalPain2, Cog WM1, Cog WM2, Cog Inhib1, Cog Inhib2, Cog RespSel1, Cog RespSel2, Emotion_Aversiveimages1, Emotion_Aversiveimages2, Emotion_Rejection1, Emotion_VicariousPain2, Emotion_AversiveSound1, Emotion_AversiveSound2
This loads the data we'll use. The dataset is from Kragel et al. 2018, Nature Neuroscience. It consists of 270 participants systematically sampled from 18 studies spanning 3 domains: Pain, Cognitive Control, and Negative Emotion.
Each domain has three subdomains, each subdomain includes two studies, and each study has n = 15 participants. The .metadata_table attribute contains a Matlab table object with the domain, subdomain, and study indicators.
dat.metadata_table(1:5, :)

Look at the data and consider scaling options

It's always a good idea to look at your data! With this dataset, because images come from different studies, there is no guarantee they are on the same scale. We'll apply a simple transformation here, z-scoring each participant's image: subtracting the mean and dividing by the standard deviation (across voxels) of each image. After z-scoring, we're effectively analyzing the pattern of values across voxels rather than their absolute intensity.
Z-scoring treats the images independently and does not depend on estimating any parameters (e.g., means, ranks) across the group, which is crucial for maintaining validity when using cross-validation.
%% z-score each participant
dat = rescale(dat, 'zscoreimages');
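For intuition, this is (as a sketch) roughly equivalent to standardizing each column of the voxels x images data matrix by hand, since MATLAB's zscore() operates column-wise; rescale() additionally handles details such as excluded voxels:
% Roughly equivalent manual version: make each image (column of .dat) have
% mean 0 and standard deviation 1 across voxels
dat.dat = zscore(dat.dat);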

Questions to answer

  1. How important is scaling here? What was wrong with the data, if anything, before scaling?
  2. What kind of information that could be useful for classification might you be losing by scaling each image?

Define training/cross-validation and test (generalization) datasets

We'll use some Matlab code for manipulating sets to pull out a training set for the Pain vs. Cognitive Control classification. We'll train on three studies from each domain, one from each of its three subdomains.

Training dataset

wh_studies = [1 3 5 7 9 11];
wh_subjects = ismember(dat.metadata_table.Studynumber, wh_studies);
% List the unique subdomains
unique(dat.metadata_table.Subdomain(wh_subjects))
ans = 6×1 cell
'Inhibition'
'Mechanical'
'ResponseSelect'
'Thermal'
'Visceral'
'WorkingMem'
Now we'll define a training_dat object, selecting only the included subjects. We do this by calling remove_empty() to exclude those that are not (~) in the list.
training_dat = dat;
training_dat = remove_empty(training_dat, [], ~wh_subjects);
% We may have to update some fields manually, like the metadata table
training_dat.metadata_table = training_dat.metadata_table(wh_subjects, :);
To train the SVM, we have different options for how to set up the problem. Here we'll take the most direct route: assign outcome labels of 1 (Pain) and -1 (Cognitive Control) to training_dat.Y, and train on all of the selected studies together.
wh_studies = [1 3 5];
train_plus = ismember(dat.metadata_table.Studynumber, wh_studies);
wh_studies = [7 9 11];
train_minus = ismember(dat.metadata_table.Studynumber, wh_studies);
% generate vector of outcome labels, 1 and -1.
% Cast as double because we need numeric (not logical) inputs.
train_Y = double(train_plus) - double(train_minus);
training_dat.Y = train_Y(wh_subjects);
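A quick sanity check, not part of the original lab code, that the training labels are balanced and that no unlabeled images slipped through:
% Count images in each class; with 3 studies of n = 15 per class, both counts should be 45
disp([sum(training_dat.Y == 1), sum(training_dat.Y == -1), sum(training_dat.Y == 0)])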

Test dataset

We'll use the same type of code to define a test dataset, using different studies from the same subdomains:
wh_studies = [2 4 6];
test_plus = ismember(dat.metadata_table.Studynumber, wh_studies);
wh_studies = [8 10 12];
test_minus = ismember(dat.metadata_table.Studynumber, wh_studies);
% generate vector of outcome labels, 1 and -1.
% Cast as double because we need numeric (not logical) inputs.
test_Y = double(test_plus) - double(test_minus);
wh_subjects = test_Y ~= 0; % logical vector for any subject in test set
test_dat = dat;
test_dat = remove_empty(test_dat, [], ~wh_subjects);
test_dat.Y = test_Y(wh_subjects);

Questions to answer

  1. Why is it important to separate training and test datasets?
  2. Have we done this adequately here? Why or why not?
  3. What are the advantages, if any, of including different kinds of studies with different subdomains in our training set?
  4. What are the disadvantages of the above, if any?
  5. What alternative splits of training and test data could make sense here, and why would they be advantageous?

Train the SVM to discriminate Pain vs. Cognitive Control

Our first goal is to train a classifier that separates Pain from Cognitive Control:
[cverr, stats, optout] = predict(training_dat, 'cv_svm');
Cross-validated prediction with algorithm cv_svm, 5 folds
Training...training svm kernel linear C=1 optimizer=andre.... Testing... Done in 1.44 sec
Completed fit for all data in: 0 hours 0 min 2 secs
Training...training svm kernel linear C=1 optimizer=andre.... Testing... Done in 0.66 sec
Fold 1/5 done in: 0 hours 0 min 1 sec
Training...training svm kernel linear C=1 optimizer=andre.... Testing... Done in 0.55 sec
Fold 2/5 done in: 0 hours 0 min 1 sec
Training...training svm kernel linear C=1 optimizer=andre.... Testing... Done in 0.58 sec
Fold 3/5 done in: 0 hours 0 min 1 sec
Training...training svm kernel linear C=1 optimizer=andre.... Testing... Done in 0.60 sec
Fold 4/5 done in: 0 hours 0 min 1 sec
Training...training svm kernel linear C=1 optimizer=andre.... Testing... Done in 0.59 sec
Fold 5/5 done in: 0 hours 0 min 1 sec
Total Elapsed Time = 0 hours 0 min 6 sec
Number of unique values in dataset: 201275  Bit rate: 17.62 bits
predict() runs 5-fold cross-validation by default. There are optional inputs for the algorithm, the number and type of cross-validation folds, and other settings (see the function help for the full list).
Other related CANlab functions are xval_SVM(), xval_SVR(), canlab_run_paired_SVM(), and other functions that start with xval_.
paired_d is the cross-validated effect size in paired (forced-choice) comparison. 0.5 is a medium effect, 0.7 a large effect, and anything larger is "very large".
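As a sketch of how a paired effect size of this kind can be computed, assume hypothetical vectors scores_on and scores_off of each participant's cross-validated scores for the two conditions:
% Paired (within-person) Cohen's d: mean of the within-person score differences
% divided by their standard deviation across participants
score_diff = scores_on - scores_off;
paired_d = mean(score_diff) / std(score_diff);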
stats contains an output structure for the SVM. It includes fields we'll use below: stats.Y (the true class labels), stats.yfit (the cross-validated class predictions), stats.dist_from_hyperplane_xval (the cross-validated SVM scores, i.e., distances from the hyperplane), and stats.weight_obj (the voxel weight map as an image object).
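As a look ahead, stats.weight_obj can also be applied to the independent test images defined earlier. Here is a minimal sketch using the CANlab apply_mask() 'pattern_expression' option (the dot product of each test image with the weight map; this omits the intercept, so roc_plot() is used to choose a threshold from the test scores rather than fixing it at zero):
% Pattern expression of the trained weight map in each held-out test image
pexp_test = apply_mask(test_dat, stats.weight_obj, 'pattern_expression', 'ignore_missing');
% Evaluate classification of the test labels (Pain = 1 vs. Cognitive Control = -1)
ROC_test = roc_plot(pexp_test, logical(test_dat.Y > 0), 'color', 'b');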

Display the cross-validated results

We can extract statistics from the stats structure in different ways.
Confusion matrix. Let's get a confusion matrix, and a table for each true class. The confusion matrix shows the proportions (or counts) of each true class (rows) assigned to each class (columns). It can thus reveal what kinds of errors the model is making.
[m,aprime,corr,far,missclass, mprop, statstable] = confusion_matrix(stats.Y, stats.yfit);
% Create a table object from the matrix mprop, which shows proportions of
% each true class (rows) assigned to each class (columns)
mytable = array2table(mprop, 'VariableNames', {'Pred Cog' 'Pred Pain'}, 'RowNames', {'Actual Cog' 'Actual Pain'});
disp(mytable)
                   Pred Cog    Pred Pain
    Actual Cog     0.88889     0.11111
    Actual Pain    0.088889    0.91111
disp(statstable)
    Freq    Perc_cases    Sens       Spec       PPV        Aprime
    45      0.5           0.88889    0.91111    0.90909    2.5683
    45      0.5           0.91111    0.88889    0.8913     2.5683
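The balanced accuracy discussed below can be read directly off the diagonal of the proportion matrix mprop; a quick sketch:
% Balanced accuracy: average of the per-class accuracies (diagonal of mprop)
balanced_acc = mean(diag(mprop));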
The ROC plot. Display the performance in a Receiver Operating Characteristic curve. This plots the Sensitivity (hit rate) against 1 - Specificity (the False Alarm rate). It's possible to have 100% sensitivity by always classifying an observation as a "yes" (1), at the cost of 100% false alarms. Likewise, it's possible to have 100% specificity by always classifying an observation as a "no" (-1), at the cost of missing every "yes". The ROC plot displays the tradeoff curve between the two. A concave curve (high Sensitivity with low [1 - Specificity], i.e., high Specificity) shows good classification performance and indicates true predictive signal. A straight diagonal line or a convex curve indicates no true predictive signal. roc_plot() will also choose a threshold that maximizes either the overall accuracy or the balanced accuracy (the average accuracy across classes, which matters when some classes have fewer observations); the latter is recommended. Here is some of the output of roc_plot(), saved in a structure:
create_figure('ROC');
ROC = roc_plot(stats.yfit, logical(stats.Y > 0), 'color', 'r');
ROC_PLOT Output: Single-interval, Optimal overall accuracy
Threshold: -1.00   Sens: 91% CI(82%-98%)   Spec: 89% CI(79%-98%)   PPV: 89% CI(79%-98%)
Nonparametric AUC: 0.85   Parametric d_a: 2.64
Accuracy: 90% +- 3.2% (SE), P = 0.000000
Plot the individual subjects. This is always useful to get a better picture of what the data look like!
create_figure('subjects');
plot(stats.dist_from_hyperplane_xval, 'o');
plot_horizontal_line(0);
xlabel('Participant'); ylabel('Classifier score');
% Plot the scores
dat_to_plot = {stats.dist_from_hyperplane_xval(stats.Y > 0) stats.dist_from_hyperplane_xval(stats.Y < 0)};
create_figure('SVM distance scores by class');
barplot_columns(dat_to_plot, 'title', 'Training data SVM scores', 'nofigure', 'colors', {[.7 .2 .2] [.5 .3 .7]});
Column 1:  Column 2:
---------------------------------------------
Tests of column means against zero
---------------------------------------------
    Name         Mean_Value    Std_Error    T          P             Cohens_d
    ___________  __________    _________    _______    __________    ________
    {'Col 1'}     0.61095      0.060264      10.138    4.3724e-13     1.5113
    {'Col 2'}    -0.56656      0.068587     -8.2604    1.7104e-10    -1.2314
set(gca, 'XTickLabel', {'Pain' 'Cog Control'});
ylabel('Cross-validated SVM distance');
xlabel('Domain')

Questions to answer

  1. How accurate is the model? How sensitive and specific? Describe in words what these values mean.
  2. What is the SVM classification boundary here for separating pain from cognitive control? If I have an observation with an SVM distance score of 0.5, for example, what class would the model predict? How about a score of -0.7?
  3. How many subjects are misclassified here? Do they come more from some subdomains than others? (Note: you will want to create plots to show this that are not in the code example!)

Display the SVM map and results

Display the unthresholded weight map:
create_figure('weight map'); axis off;
montage(stats.weight_obj);
Setting up fmridisplay objects
sagittal montage: 4038 voxels displayed, 197237 not displayed on these slices
sagittal montage: 4000 voxels displayed, 197275 not displayed on these slices
sagittal montage: 3849 voxels displayed, 197426 not displayed on these slices
axial montage: 28960 voxels displayed, 172315 not displayed on these slices
axial montage: 31113 voxels displayed, 170162 not displayed on these slices