SVM classifier on paired data using CANlab tools
Support Vector Machines are a fast and efficient way to identify multivariate patterns that separate two conditions. For example, we might want to identify brain patterns that predict (or 'decode'):
- whether a working memory task was difficult or easy
- whether a person will remember or forget an image
- whether a person is experiencing high or low pain
- whether a person is actively reappraising (regulating) emotion or not
SVMs are called "maximum margin classifiers" because they identify a hyperplane (think of a 2-D boundary like a curtain, but in multidimensional space) that maximizes the separation (or "margin", in the sense that there is a buffer of empty space) between images belonging to the two classes. The outcome (or response) data is often labeled with values of 1 and -1, i.e., "on"/"off" labels for the classes.
SVMs can be run on any set of images using whole-brain patterns, multiple regions of interest, or other derived features (input variables). The SVM will try to define a hyperplane (boundary) that separates the classes of all the images (or other measures) fed into it. It delivers two outputs:
- A weight map, which is a value for each feature (brain voxel in an image). The weights specify a linear combination that is orthogonal to the hyperplane, so that a higher combination indicates the "on" class and a lower combination indicates the "off" class. If brain images are fed in as input features, the weight map will be a voxel-wise brain map.
- Scores for each image. The scores are the dot product of the weight map and the input data for each observation (i.e., the linear combination discussed above), plus an intercept, or "offset" value. The offset is estimated so that in the training data, scores > 0 are predicted to be "on" and scores < 0 are predicted to be "off".
The weight map plus the intercept is the "model".
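To make the "weight map plus intercept" idea concrete, here is a toy sketch using Matlab's built-in fitcsvm (from the Statistics and Machine Learning Toolbox, not the CANlab tools used below); the variables X, Y, w, and b are made up for illustration.
% Toy sketch: fit a linear SVM to two simulated classes and compute scores by hand
rng(1);
X = [randn(20, 5) + 0.8; randn(20, 5) - 0.8];   % 40 "images" x 5 "voxels" (features)
Y = [ones(20, 1); -ones(20, 1)];                % class labels: 1 ("on") and -1 ("off")
mdl = fitcsvm(X, Y, 'KernelFunction', 'linear');
w = mdl.Beta;                                   % the "weight map": one weight per feature
b = mdl.Bias;                                   % the intercept ("offset")
scores = X * w + b;                             % linear combination plus offset = SVM score
predicted_class = sign(scores);                 % scores > 0 --> "on", scores < 0 --> "off"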
The scores can be evaluated for accuracy using cross-validation or by applying the model to an independent test dataset. Evaluating the model on data with independent errors is the only way to keep overfitting from inflating the estimate and to obtain a valid measure of the model's classification performance.
The comparisons between classes can either be between-person, where different examples of "on" and "off" images come from different participants, or within-person (paired), where there are one or more examples of an "on" and an "off" condition for each participant. The SVM classifier will simply try to separate the classes based on the features fed into it, regardless of the identity of the participant. But when it delivers scores back, we can keep track of which scores came from the same person and compute a within-person, forced-choice accuracy. This assesses whether, given two images from the same participant, we can guess which one was collected during the "on" condition and which during the "off" condition. This is usually quite a bit higher than the single-interval accuracy, which makes a guess for each image based on its absolute SVM score alone. There are many individual differences that can influence the overall score (e.g., vasculature, overall BOLD magnitude, variations in brain morphology and normalization) and lead to poor single-interval accuracy, but many of these sources drop out when we compare paired images from the same person.
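Here is a minimal sketch of the difference between single-interval and forced-choice accuracy, using simulated scores; the variables scores_on and scores_off are hypothetical stand-ins for each person's cross-validated SVM scores in the two conditions.
% Simulated SVM scores for 20 participants, one "on" and one "off" image each
scores_on = randn(20, 1) + 0.5;
scores_off = randn(20, 1) - 0.5;
% Single-interval accuracy: classify each image by the sign of its own score
single_interval_acc = mean([scores_on > 0; scores_off < 0]);
% Forced-choice accuracy: within each pair, is the "on" score larger than the "off" score?
forced_choice_acc = mean(scores_on > scores_off);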
In this lab, we'll use the function canlab_run_paired_SVM to train on a mix of images, but calculate the forced-choice accuracy within a person.
Set up directory and load files
dat = load_image_set('Kragel18_alldata');
Loading /Users/torwager/Dropbox (Dartmouth College)/COURSES/Courses_Dartmouth/2021_3_Spring_fMRI_Class/PSYC60_Shared_resources_for_students/datasets/kragel_2018_nat_neurosci_270_subjects_test_images.mat
Loaded images:
ThermalPain1 (15 images)
ThermalPain2 (15 images)
VisceralPain1 (15 images)
VisceralPain2 (15 images)
MechanicalPain1 (15 images)
MechanicalPain2 (15 images)
Cog WM1 (15 images)
Cog WM2 (15 images)
Cog Inhib1 (15 images)
Cog Inhib2 (15 images)
Cog RespSel1 (15 images)
Cog RespSel2 (15 images)
Emotion_Aversiveimages1 (15 images)
Emotion_Aversiveimages2 (15 images)
Emotion_Rejection1 (15 images)
Emotion_VicariousPain2 (15 images)
Emotion_AversiveSound1 (15 images)
Emotion_AversiveSound2 (15 images)
The dataset includes 18 studies spanning three domains, each with three subdomains:
- pain, of three subtypes: thermal, visceral, and mechanical
- cognitive control, of three subtypes: working memory, response inhibition, and response selection
- negative affect, of three subtypes: aversive images from the IAPS set, social affect (romantic rejection and vicarious pain), and aversive sounds from the IADS set.
Each subdomain includes two studies, with n = 15 per study. The .metadata_table attribute contains a Matlab table object with the domain, subdomain, and study indicators.
dat.metadata_table(1:5, :)
Look at the data and consider scaling options
It's always a good idea to look at your data! With this dataset, because images come from different studies, there is no guarantee they are on the same scale. We'll do a simple transformation here, z-scoring each participant's image. This subtracts the mean and divides by the standard deviation (across voxels) of each image. After z-scoring, we're effectively analyzing the pattern of values across voxels rather than their absolute intensity.
Z-scoring treats the images independently and does not depend on estimating any parameters (e.g., means, ranks) across the group, which is crucial for maintaining validity when using cross-validation.
%% z-score each participant
dat = rescale(dat, 'zscoreimages');
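Conceptually, the 'zscoreimages' option does something like the following sketch (not the CANlab implementation). In an fmri_data object, dat.dat is a voxels x images matrix, so each column is one image.
% Z-score each image (column) across voxels: subtract its mean, divide by its standard deviation
Xz = (dat.dat - mean(dat.dat, 1)) ./ std(dat.dat, 0, 1);
% Each column of Xz now has mean 0 and SD 1 across voxels, so only the spatial
% pattern of values, not the overall image intensity, remains.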
Questions to answer
- How important is scaling here? What was wrong with the data, if anything, before scaling?
- What kind of information that could be useful for classification might you be losing by scaling each image?
Define training/cross-validation and test (generalization) datasets
We'll use some Matlab code for manipulating sets to pull out a training set for classifying Pain vs. Cognitive Control. We'll train on three studies from each domain, one from each of its three subdomains.
Training dataset
wh_studies = [1 3 5 7 9 11];
wh_subjects = ismember(dat.metadata_table.Studynumber, wh_studies);
% List the unique subdomains
unique(dat.metadata_table.Subdomain(wh_subjects))
'Inhibition'
'Mechanical'
'ResponseSelect'
'Thermal'
'Visceral'
'WorkingMem'
Now we'll define a training_dat object, selecting only the included subjects. We do this by calling remove_empty() to exclude those that are not (~) in the list.
training_dat = remove_empty(dat, [], ~wh_subjects);
% We may have to update some fields manually, like the metadata table
training_dat.metadata_table = training_dat.metadata_table(wh_subjects, :);
To train the SVM, we have different options. We can:
- Use the predict( ) method after assigning 1 and -1 labels to training_dat.Y
- Use xval_SVM on the data matrix, training_dat.dat'
- Separate them into two matched objects, if the images are paired (they're not here).
We'll use the first of these.
train_plus = ismember(dat.metadata_table.Studynumber, [1 3 5]);   % pain training studies (thermal, visceral, mechanical)
train_minus = ismember(dat.metadata_table.Studynumber, [7 9 11]); % cognitive control training studies
% generate vector of outcome labels, 1 and -1.
% Cast as double because we need numeric (not logical) inputs.
train_Y = double(train_plus) - double(train_minus);
training_dat.Y = train_Y(wh_subjects);
Test dataset
We'll use the same type of code to define a test dataset, using different studies from the same subdomains:
test_plus = ismember(dat.metadata_table.Studynumber, [2 4 6]);     % held-out pain studies
test_minus = ismember(dat.metadata_table.Studynumber, [8 10 12]);  % held-out cognitive control studies
% generate vector of outcome labels, 1 and -1.
% Cast as double because we need numeric (not logical) inputs.
test_Y = double(test_plus) - double(test_minus);
wh_subjects = test_Y ~= 0; % logical vector for any subject in test set
test_dat = remove_empty(dat, [], ~wh_subjects);
test_dat.metadata_table = dat.metadata_table(wh_subjects, :);  % keep the metadata table in sync, as above
test_dat.Y = test_Y(wh_subjects);
Questions to answer
- Why is it important to separate training and test datasets?
- Have we done this adequately here? Why or why not?
- What are the advantages, if any, of including different kinds of studies with different subdomains in our training set?
- What are the disadvantages of the above, if any?
- What alternative splits of training and test data could make sense here, and why would they be advantageous?
Train the SVM to discriminate Pain vs. Cognitive Control
Our first goal is to train a classifier that separates pain from cog control:
[cverr, stats, optout] = predict(training_dat, 'cv_svm');
Cross-validated prediction with algorithm cv_svm, 5 folds
Training...training svm kernel linear C=1 optimizer=andre....
Testing...
Done in 1.44 sec
Completed fit for all data in: 0 hours 0 min 2 secs
Training...training svm kernel linear C=1 optimizer=andre....
Testing...
Done in 0.66 sec
Fold 1/5 done in: 0 hours 0 min 1 sec
Training...training svm kernel linear C=1 optimizer=andre....
Testing...
Done in 0.55 sec
Fold 2/5 done in: 0 hours 0 min 1 sec
Training...training svm kernel linear C=1 optimizer=andre....
Testing...
Done in 0.58 sec
Fold 3/5 done in: 0 hours 0 min 1 sec
Training...training svm kernel linear C=1 optimizer=andre....
Testing...
Done in 0.60 sec
Fold 4/5 done in: 0 hours 0 min 1 sec
Training...training svm kernel linear C=1 optimizer=andre....
Testing...
Done in 0.59 sec
Fold 5/5 done in: 0 hours 0 min 1 sec
Total Elapsed Time = 0 hours 0 min 6 sec
Number of unique values in dataset: 201275 Bit rate: 17.62 bits
predict() runs 5-fold cross-validation by default. There are optional inputs for:
- Controlling the cross-validation. You should always leave out ALL images from the same participant together in a cross-validation fold. If you have more than one input image per condition per subject, you will need to customize your holdout set, e.g., with stratified_holdout_set( ) or xval_stratified_holdout_leave_whole_subject_out( ); a sketch of building such a subject-level holdout vector appears below.
- Running a bootstrap test to obtain P-values for the SVM weights and identify voxels that make a significant contribution to prediction
Other related CANlab functions are xval_SVM( ), xval_SVR( ), canlab_run_paired_SVM( ), and other functions that start with xval_.
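Here is a minimal sketch (plain Matlab, with a hypothetical subject_id vector) of building a fold vector that keeps all of a subject's images together; check the help for predict( ) and the functions above for how to pass a custom holdout set in your CANlab version.
% Hypothetical example: 90 subjects, 2 images each (one per condition)
subject_id = repelem((1:90)', 2);             % one subject label per image
subjects = unique(subject_id);
nfolds = 5;
% Assign each SUBJECT (not each image) to a fold at random
subj_fold = mod(randperm(numel(subjects)), nfolds) + 1;
% Expand to one fold label per image, so a subject's images always share a fold
fold_id = zeros(size(subject_id));
for i = 1:numel(subjects)
    fold_id(subject_id == subjects(i)) = subj_fold(i);
end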
paired_d is the cross-validated effect size for the paired (forced-choice) comparison. 0.5 is a medium effect, 0.7 a large effect, and anything larger is "very large". A small sketch of how this is computed from the paired scores follows the list below.
stats contains a stats output structure for the SVM. This includes specific fields:
- stats.paired_hyperplane_dist_scores: The distance from the hyperplane for the paired pos - neg images for each unit (e.g., subject). Positive indicates correct classification.
- stats.paired_d: Simple Cohen's d effect size for forced-choice cross-validated classification. This is the main metric of interest for evaluating classification performance.
- stats.paired_accuracy: Accuracy of forced-choice (two-choice) classification, cross-validated
- stats.ROC: Additional output from roc_plot( ), which calculates a Receiver Operating Characteristic curve. This output includes the sensitivity/specificity/PPV (which will all be identical for two-choice paired classification)
- stats.weight_obj: An fmri_data object with the weights. If bootstrapping, a statistic_image object including P-values that can be thresholded.
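To see how the paired effect size relates to the paired scores, here is a small sketch using a simulated stand-in (the variable paired_diffs is hypothetical) for stats.paired_hyperplane_dist_scores:
% Simulated paired (pos - neg) distance-from-hyperplane scores for 20 subjects
paired_diffs = randn(20, 1) + 0.6;
paired_d = mean(paired_diffs) / std(paired_diffs);   % Cohen's d for the paired comparison
paired_acc = mean(paired_diffs > 0);                 % forced-choice accuracy: positive difference = correct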
Display the cross-validated results
We can extract statistics from the stats structure in different ways.
Confusion matrix. Let's get a confusion matrix, and a table for each true class. The confusion matrix shows the proportions (or counts) of each true class (rows) assigned to each class (columns). It can thus reveal what kinds of errors the model is making.
[m,aprime,corr,far,missclass, mprop, statstable] = confusion_matrix(stats.Y, stats.yfit);
% Create a table object from the matrix mprop, which shows proportions of
% each true class (rows) assigned to each class (columns)
mytable = array2table(mprop, 'VariableNames', {'Pred Cog' 'Pred Pain'}, 'RowNames', {'Actual Cog' 'Actual Pain'});
disp(mytable)
Pred Cog Pred Pain
________ _________
Actual Cog 0.88889 0.11111
Actual Pain 0.088889 0.91111
disp(statstable)
Freq Perc_cases Sens Spec PPV Aprime
____ __________ _______ _______ _______ ______
45 0.5 0.88889 0.91111 0.90909 2.5683
45 0.5 0.91111 0.88889 0.8913 2.5683
The ROC plot. Display the performance in a Receiver Operating Characteristic curve. This plots Sensitivity against 1 - Specificity (the false alarm rate). It's possible to have 100% sensitivity by always classifying an observation as a "yes" (1), at the cost of a 100% false alarm rate. Likewise, it's possible to have 100% specificity by always classifying an observation as a "no" (-1). The ROC plot displays the tradeoff curve between the two. A concave curve (high Sensitivity at low [1 - Specificity], i.e., high Specificity) shows good classification performance and indicates true predictive signal. A straight diagonal line or a convex curve indicates no true predictive signal. roc_plot( ) will also choose a threshold that maximizes either the overall accuracy or the balanced accuracy (the average accuracy across classes, even if some classes have fewer observations); the latter is recommended. Here is some of the output of roc_plot( ), saved in a structure (a small sketch of how these quantities are computed follows the list):
- Sensitivity: Chances of predicting a "yes" given true "yes"
- Specificity: Chances of predicting a "no" given true "no"
- Positive predictive value: Chances of true "yes" given predicted "yes"
- Negative predictive value: Chances of true "no" given predicted "no"
- d or d_a: Effect size for classification; higher values indicate stronger predictive signal.
- AUC: Area under the ROC curve. Higher values indicate more true signal, with a max of 1 (100% sensitivity with 100% specificity, perfect classification).
- sensitivity_ci and other _ci fields: 95% confidence intervals for sensitivity and the other statistics
- Accuracy: Overall accuracy at the selected threshold.
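As a small sketch of what these quantities mean (using short hypothetical truth and pred vectors, not the roc_plot( ) implementation):
truth = [1 1 1 1 0 0 0 0]';   % true class: 1 = "yes", 0 = "no"
pred  = [1 1 1 0 0 0 1 0]';   % predicted class
sens = sum(pred == 1 & truth == 1) / sum(truth == 1);   % P(predict yes | true yes)
spec = sum(pred == 0 & truth == 0) / sum(truth == 0);   % P(predict no | true no)
ppv  = sum(pred == 1 & truth == 1) / sum(pred == 1);    % P(true yes | predicted yes)
npv  = sum(pred == 0 & truth == 0) / sum(pred == 0);    % P(true no | predicted no)
acc  = mean(pred == truth);                             % overall accuracy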
create_figure('ROC');
ROC = roc_plot(stats.yfit, logical(stats.Y > 0), 'color', 'r');
ROC_PLOT Output: Single-interval, Optimal overall accuracy
Threshold: -1.00 Sens: 91% CI(82%-98%) Spec: 89% CI(79%-98%) PPV: 89% CI(79%-98%) Nonparametric AUC: 0.85 Parametric d_a: 2.64 Accuracy: 90% +- 3.2% (SE), P = 0.000000
Plot the individual subjects. This is always useful to get a better picture of what the data look like!
create_figure('subjects');
plot(stats.dist_from_hyperplane_xval, 'o');
xlabel('Participant'); ylabel('Classifier score');
dat_to_plot = {stats.dist_from_hyperplane_xval(stats.Y > 0) stats.dist_from_hyperplane_xval(stats.Y < 0)};
create_figure('SVM distance scores by class');
barplot_columns(dat_to_plot, 'title', 'Training data SVM scores', 'nofigure', 'colors', {[.7 .2 .2] [.5 .3 .7]});
Column 1: Column 2:
---------------------------------------------
Tests of column means against zero
---------------------------------------------
Name Mean_Value Std_Error T P Cohens_d
___________ __________ _________ _______ __________ ________
{'Col 1'} 0.61095 0.060264 10.138 4.3724e-13 1.5113
{'Col 2'} -0.56656 0.068587 -8.2604 1.7104e-10 -1.2314
set(gca, 'XTickLabel', {'Pain' 'Cog Control'});
ylabel('Cross-validated SVM distance');
Questions to answer
- How accurate is the model? How sensitive and specific? Describe in words what these values mean.
- What is the SVM classification boundary here for separating pain from cognitive control? If I have an observation with an SVM distance score of 0.5, for example, what class would the model predict? How about a score of -0.7?
- How many subjects are misclassified here? Do they come more from some subdomains than others? (Note: you will want to create plots to show this that are not in the code example!)
Display the SVM map and results
Display the unthresholded weight map:
create_figure('weight map'); axis off;
montage(stats.weight_obj);
Setting up fmridisplay objects
sagittal montage: 4038 voxels displayed, 197237 not displayed on these slices
sagittal montage: 4000 voxels displayed, 197275 not displayed on these slices
sagittal montage: 3849 voxels displayed, 197426 not displayed on these slices
axial montage: 28960 voxels displayed, 172315 not displayed on these slices
axial montage: 31113 voxels displayed, 170162 not displayed on these slices
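Finally, the held-out test dataset defined earlier can be used to estimate how well the model generalizes to new studies. Here is a minimal sketch, assuming the test_dat object and the stats structure created above; note that apply_mask with 'pattern_expression' returns the dot product of each image with the weight map but omits the SVM intercept, so 0 is only an approximate decision threshold here.
% Pattern-expression scores for each test image, and a rough single-interval accuracy
test_scores = apply_mask(test_dat, stats.weight_obj, 'pattern_expression', 'ignore_missing');
test_acc = mean(sign(test_scores) == sign(test_dat.Y));
fprintf('Approximate test accuracy: %3.0f%%\n', 100 * test_acc);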