Machine Learning Workflow: SVM with RBF Kernel (Heart Data)

Biostat 212A

Author

Dr. Jin Zhou @ UCLA

Published

February 23, 2026

Display system information for reproducibility.

sessionInfo()
R version 4.5.2 (2025-10-31)
Platform: aarch64-apple-darwin20
Running under: macOS Sequoia 15.7.4

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.1

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Los_Angeles
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] htmlwidgets_1.6.4 compiler_4.5.2    fastmap_1.2.0     cli_3.6.5        
 [5] tools_4.5.2       htmltools_0.5.8.1 rstudioapi_0.17.1 yaml_2.3.10      
 [9] rmarkdown_2.29    knitr_1.50        jsonlite_2.0.0    xfun_0.56        
[13] digest_0.6.37     rlang_1.1.6       evaluate_1.0.5   
import IPython
print(IPython.sys_info())
{'commit_hash': '5ed988a91',
 'commit_source': 'installation',
 'default_encoding': 'utf-8',
 'ipython_path': '/Users/jinjinzhou/.virtualenvs/r-tensorflow/lib/python3.10/site-packages/IPython',
 'ipython_version': '8.33.0',
 'os_name': 'posix',
 'platform': 'macOS-15.7.4-arm64-arm-64bit',
 'sys_executable': '/Users/jinjinzhou/.virtualenvs/r-tensorflow/bin/python',
 'sys_platform': 'darwin',
 'sys_version': '3.10.16 (main, Mar  3 2025, 20:01:33) [Clang 16.0.0 '
                '(clang-1600.0.26.6)]'}

1 Overview

We illustrate the typical machine learning workflow for support vector machines using the Heart data set. The outcome is AHD (Yes or No).

  1. Initial splitting into test and non-test sets.

  2. Pre-processing of data: dummy coding categorical variables, standardizing numerical variables, imputing missing values, …

  3. Tune the SVM algorithm using 5-fold cross-validation (CV) on the non-test data.

  4. Choose the best model by CV and refit it on the whole non-test data.

  5. Final classification on the test data.

2 Heart data

The goal is to predict the binary outcome AHD (Yes or No) of patients.

# Load libraries
library(GGally)
library(gtsummary)
library(kernlab)
library(tidyverse)
library(tidymodels)

# Load the `Heart.csv` data.
Heart <- read_csv("../data/Heart.csv") %>% 
  # first column is patient ID, which we don't need
  select(-1) %>%
  # RestECG is categorical with value 0, 1, 2
  mutate(RestECG = as.character(RestECG)) %>%
  print(width = Inf)
# A tibble: 303 × 14
     Age   Sex ChestPain    RestBP  Chol   Fbs RestECG MaxHR ExAng Oldpeak Slope
   <dbl> <dbl> <chr>         <dbl> <dbl> <dbl> <chr>   <dbl> <dbl>   <dbl> <dbl>
 1    63     1 typical         145   233     1 2         150     0     2.3     3
 2    67     1 asymptomatic    160   286     0 2         108     1     1.5     2
 3    67     1 asymptomatic    120   229     0 2         129     1     2.6     2
 4    37     1 nonanginal      130   250     0 0         187     0     3.5     3
 5    41     0 nontypical      130   204     0 2         172     0     1.4     1
 6    56     1 nontypical      120   236     0 0         178     0     0.8     1
 7    62     0 asymptomatic    140   268     0 2         160     0     3.6     3
 8    57     0 asymptomatic    120   354     0 0         163     1     0.6     1
 9    63     1 asymptomatic    130   254     0 2         147     0     1.4     2
10    53     1 asymptomatic    140   203     1 2         155     1     3.1     3
      Ca Thal       AHD  
   <dbl> <chr>      <chr>
 1     0 fixed      No   
 2     3 normal     Yes  
 3     2 reversable Yes  
 4     0 normal     No   
 5     0 normal     No   
 6     0 normal     No   
 7     2 normal     Yes  
 8     0 normal     No   
 9     1 reversable Yes  
10     0 reversable Yes  
# ℹ 293 more rows
# Numerical summaries stratified by the outcome `AHD`.
Heart %>% tbl_summary(by = AHD)
Characteristic         No, N = 164          Yes, N = 139
Age                    52 (45, 59)          58 (52, 62)
Sex                    92 (56%)             114 (82%)
ChestPain
    asymptomatic       39 (24%)             105 (76%)
    nonanginal         68 (41%)             18 (13%)
    nontypical         41 (25%)             9 (6.5%)
    typical            16 (9.8%)            7 (5.0%)
RestBP                 130 (120, 140)       130 (120, 145)
Chol                   235 (209, 268)       249 (217, 284)
Fbs                    23 (14%)             22 (16%)
RestECG
    0                  95 (58%)             56 (40%)
    1                  1 (0.6%)             3 (2.2%)
    2                  68 (41%)             80 (58%)
MaxHR                  161 (149, 172)       142 (125, 157)
ExAng                  23 (14%)             76 (55%)
Oldpeak                0.20 (0.00, 1.05)    1.40 (0.50, 2.50)
Slope
    1                  106 (65%)            36 (26%)
    2                  49 (30%)             91 (65%)
    3                  9 (5.5%)             12 (8.6%)
Ca
    0                  130 (81%)            46 (33%)
    1                  21 (13%)             44 (32%)
    2                  7 (4.3%)             31 (22%)
    3                  3 (1.9%)             17 (12%)
    Unknown            3                    1
Thal
    fixed              6 (3.7%)             12 (8.7%)
    normal             129 (79%)            37 (27%)
    reversable         28 (17%)             89 (64%)
    Unknown            1                    1
Statistics: Median (Q1, Q3); n (%)
# Graphical summary:
# Heart %>% ggpairs()
# Load the pandas library
import pandas as pd
# Load numpy for array manipulation
import numpy as np
# Load seaborn plotting library
import seaborn as sns
import matplotlib.pyplot as plt

# Set font sizes in plots
sns.set(font_scale = 1.2)
# Display all columns
pd.set_option('display.max_columns', None)

Heart = pd.read_csv("../data/Heart.csv")
Heart
     Unnamed: 0  Age  Sex     ChestPain  RestBP  Chol  Fbs  RestECG  MaxHR  \
0             1   63    1       typical     145   233    1        2    150   
1             2   67    1  asymptomatic     160   286    0        2    108   
2             3   67    1  asymptomatic     120   229    0        2    129   
3             4   37    1    nonanginal     130   250    0        0    187   
4             5   41    0    nontypical     130   204    0        2    172   
..          ...  ...  ...           ...     ...   ...  ...      ...    ...   
298         299   45    1       typical     110   264    0        0    132   
299         300   68    1  asymptomatic     144   193    1        0    141   
300         301   57    1  asymptomatic     130   131    0        0    115   
301         302   57    0    nontypical     130   236    0        2    174   
302         303   38    1    nonanginal     138   175    0        0    173   

     ExAng  Oldpeak  Slope   Ca        Thal  AHD  
0        0      2.3      3  0.0       fixed   No  
1        1      1.5      2  3.0      normal  Yes  
2        1      2.6      2  2.0  reversable  Yes  
3        0      3.5      3  0.0      normal   No  
4        0      1.4      1  0.0      normal   No  
..     ...      ...    ...  ...         ...  ...  
298      0      1.2      2  0.0  reversable  Yes  
299      0      3.4      2  2.0  reversable  Yes  
300      1      1.2      2  1.0  reversable  Yes  
301      0      0.0      2  1.0      normal  Yes  
302      0      0.0      1  NaN      normal   No  

[303 rows x 15 columns]
# Numerical summaries
Heart.describe(include = 'all')
        Unnamed: 0         Age         Sex     ChestPain      RestBP  \
count   303.000000  303.000000  303.000000           303  303.000000   
unique         NaN         NaN         NaN             4         NaN   
top            NaN         NaN         NaN  asymptomatic         NaN   
freq           NaN         NaN         NaN           144         NaN   
mean    152.000000   54.438944    0.679868           NaN  131.689769   
std      87.612784    9.038662    0.467299           NaN   17.599748   
min       1.000000   29.000000    0.000000           NaN   94.000000   
25%      76.500000   48.000000    0.000000           NaN  120.000000   
50%     152.000000   56.000000    1.000000           NaN  130.000000   
75%     227.500000   61.000000    1.000000           NaN  140.000000   
max     303.000000   77.000000    1.000000           NaN  200.000000   

              Chol         Fbs     RestECG       MaxHR       ExAng  \
count   303.000000  303.000000  303.000000  303.000000  303.000000   
unique         NaN         NaN         NaN         NaN         NaN   
top            NaN         NaN         NaN         NaN         NaN   
freq           NaN         NaN         NaN         NaN         NaN   
mean    246.693069    0.148515    0.990099  149.607261    0.326733   
std      51.776918    0.356198    0.994971   22.875003    0.469794   
min     126.000000    0.000000    0.000000   71.000000    0.000000   
25%     211.000000    0.000000    0.000000  133.500000    0.000000   
50%     241.000000    0.000000    1.000000  153.000000    0.000000   
75%     275.000000    0.000000    2.000000  166.000000    1.000000   
max     564.000000    1.000000    2.000000  202.000000    1.000000   

           Oldpeak       Slope          Ca    Thal  AHD  
count   303.000000  303.000000  299.000000     301  303  
unique         NaN         NaN         NaN       3    2  
top            NaN         NaN         NaN  normal   No  
freq           NaN         NaN         NaN     166  164  
mean      1.039604    1.600660    0.672241     NaN  NaN  
std       1.161075    0.616226    0.937438     NaN  NaN  
min       0.000000    1.000000    0.000000     NaN  NaN  
25%       0.000000    1.000000    0.000000     NaN  NaN  
50%       0.800000    2.000000    0.000000     NaN  NaN  
75%       1.600000    2.000000    1.000000     NaN  NaN  
max       6.200000    3.000000    3.000000     NaN  NaN  

Graphical summary:

# Graphical summaries
# pairplot is figure-level and creates its own figure, so no plt.figure() is needed
sns.pairplot(data = Heart);
plt.show()

3 Initial split into test and non-test sets

We randomly split the data into 25% test data and 75% non-test data, stratified on AHD.

# For reproducibility
set.seed(203)

data_split <- initial_split(
  Heart, 
  # stratify by AHD
  strata = "AHD", 
  prop = 0.75
  )
data_split
<Training/Testing/Total>
<227/76/303>
Heart_other <- training(data_split)
dim(Heart_other)
[1] 227  14
Heart_test <- testing(data_split)
dim(Heart_test)
[1] 76 14
from sklearn.model_selection import train_test_split

Heart_other, Heart_test = train_test_split(
  Heart, 
  train_size = 0.75,
  random_state = 425, # seed
  stratify = Heart.AHD
  )
Heart_test.shape
(76, 15)
Heart_other.shape
(227, 15)
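Stratifying matters because a plain random split can over- or under-represent Yes cases in either set. A minimal sketch (synthetic labels, not the Heart data) checking that stratify keeps the class balance nearly identical in the two splits:

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
# Synthetic features and a 60/40 binary outcome
X = rng.normal(size=(400, 3))
y = rng.choice(["No", "Yes"], size=400, p=[0.6, 0.4])

X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, train_size=0.75, random_state=425, stratify=y
)
# The "Yes" proportion is (nearly) identical in the two splits
print(round(float(np.mean(y_tr == "Yes")), 2), round(float(np.mean(y_te == "Yes")), 2))
```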

Separate \(X\) and \(y\). We will use 13 features.

num_features = ['Age', 'Sex', 'RestBP', 'Chol', 'Fbs', 'RestECG', 'MaxHR', 'ExAng', 'Oldpeak', 'Slope', 'Ca']
cat_features = ['ChestPain', 'Thal']
features = np.concatenate([num_features, cat_features])
# Non-test X and y
X_other = Heart_other[features]
y_other = Heart_other.AHD == 'Yes'
# Test X and y
X_test = Heart_test[features]
y_test = Heart_test.AHD == 'Yes'

4 Recipe (R) and Preprocessing (Python)

  • A (rough) data dictionary is available at https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/.

  • We have the following features:

    • Numerical features: Age, RestBP, Chol, Slope (1, 2 or 3), MaxHR, ExAng, Oldpeak, Ca (0, 1, 2 or 3).

    • Categorical features coded as integer: Sex (0 or 1), Fbs (0 or 1), RestECG (0, 1 or 2).

    • Categorical features coded as string: ChestPain, Thal.

  • There are missing values in Ca and Thal. Since the missing proportion is not high, we will use simple mean imputation (for the numerical feature Ca) and mode imputation (for the categorical feature Thal).

svm_recipe <- 
  recipe(
    AHD ~ ., 
    data = Heart_other
  ) %>%
  # mean imputation for Ca
  step_impute_mean(Ca) %>%
  # mode imputation for Thal
  step_impute_mode(Thal) %>%
  # create traditional dummy variables (necessary for svm)
  step_dummy(all_nominal_predictors()) %>%
  # zero-variance filter
  step_zv(all_numeric_predictors()) %>% 
  # center and scale numeric data
  step_normalize(all_numeric_predictors()) # %>%
  # estimate the means and standard deviations
 #prep(training = Heart_other, retain = TRUE)
svm_recipe

There are missing values in the Ca (quantitative) and Thal (qualitative) variables. We are going to use simple mean imputation for Ca and most_frequent imputation for Thal. This is suboptimal; a better strategy would be multiple imputation.
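A minimal illustration of these two strategies on a toy frame (hypothetical values, not the Heart data):

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

toy = pd.DataFrame({
    "Ca":   [0.0, 1.0, np.nan, 3.0],                # numeric: fill with the mean
    "Thal": ["normal", "normal", np.nan, "fixed"],  # categorical: fill with the mode
})
ca = SimpleImputer(strategy="mean").fit_transform(toy[["Ca"]]).ravel()
thal = SimpleImputer(strategy="most_frequent").fit_transform(toy[["Thal"]]).ravel()
print(ca)    # the NaN becomes (0 + 1 + 3) / 3
print(thal)  # the NaN becomes "normal"
```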

# How many NaNs
Heart.isna().sum()
Unnamed: 0    0
Age           0
Sex           0
ChestPain     0
RestBP        0
Chol          0
Fbs           0
RestECG       0
MaxHR         0
ExAng         0
Oldpeak       0
Slope         0
Ca            4
Thal          2
AHD           0
dtype: int64
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

# Transformer for categorical variables
categorical_tf = Pipeline(steps = [
  ("cat_impute", SimpleImputer(strategy = 'most_frequent')),
  ("encoder", OneHotEncoder(drop = 'first'))
])

# Transformer for continuous variables
numeric_tf = Pipeline(steps = [
  ("num_impute", SimpleImputer(strategy = 'mean')),
])

# Column transformer
col_tf = ColumnTransformer(transformers = [
  ('num', numeric_tf, num_features),
  ('cat', categorical_tf, cat_features)
])

5 Model

svm_mod <- 
  svm_rbf(
    mode = "classification",
    cost = tune(),
    rbf_sigma = tune()
  ) %>% 
  set_engine("kernlab")
svm_mod
Radial Basis Function Support Vector Machine Model Specification (classification)

Main Arguments:
  cost = tune()
  rbf_sigma = tune()

Computational engine: kernlab 
from sklearn.svm import SVC

svm_mod = SVC(
  C = 1.0,
  kernel = 'rbf',
  gamma = 'scale', # 1 / (n_features * X.var())
  probability = True,
  random_state = 425
  )
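The comment above shows the formula behind gamma = 'scale'. A sanity check on synthetic data (note that _gamma is a private scikit-learn attribute, so this is illustrative only):

```python
import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
y = rng.integers(0, 2, size=50)

clf = SVC(kernel="rbf", gamma="scale").fit(X, y)
# 'scale' means gamma = 1 / (n_features * Var(X)), variance over the flattened matrix
expected = 1.0 / (X.shape[1] * X.var())
print(np.isclose(clf._gamma, expected))  # True
```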

6 Workflow in R and pipeline in Python

Here we bundle the preprocessing step (Python) or recipe (R) and model.

svm_wf <- workflow() %>%
  add_recipe(svm_recipe) %>%
  add_model(svm_mod)
svm_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_rbf()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_impute_mean()
• step_impute_mode()
• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
Radial Basis Function Support Vector Machine Model Specification (classification)

Main Arguments:
  cost = tune()
  rbf_sigma = tune()

Computational engine: kernlab 
from sklearn.pipeline import Pipeline

pipe = Pipeline(steps = [
  ("col_tf", col_tf),
  ("model", svm_mod)
  ])
pipe
Pipeline(steps=[('col_tf',
                 ColumnTransformer(transformers=[('num',
                                                  Pipeline(steps=[('num_impute',
                                                                   SimpleImputer())]),
                                                  ['Age', 'Sex', 'RestBP',
                                                   'Chol', 'Fbs', 'RestECG',
                                                   'MaxHR', 'ExAng', 'Oldpeak',
                                                   'Slope', 'Ca']),
                                                 ('cat',
                                                  Pipeline(steps=[('cat_impute',
                                                                   SimpleImputer(strategy='most_frequent')),
                                                                  ('encoder',
                                                                   OneHotEncoder(drop='first'))]),
                                                  ['ChestPain', 'Thal'])])),
                ('model', SVC(probability=True, random_state=425))])

7 Tuning grid

Here we tune the cost and radial scale rbf_sigma.

param_grid <- grid_regular(
  cost(range = c(-8, 5)),
  rbf_sigma(range = c(-5, -3)),
  levels = c(14, 5)
  )
param_grid
# A tibble: 70 × 2
      cost rbf_sigma
     <dbl>     <dbl>
 1 0.00391   0.00001
 2 0.00781   0.00001
 3 0.0156    0.00001
 4 0.0312    0.00001
 5 0.0625    0.00001
 6 0.125     0.00001
 7 0.25      0.00001
 8 0.5       0.00001
 9 1         0.00001
10 2         0.00001
# ℹ 60 more rows
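For reference, kernlab and scikit-learn parameterize the RBF kernel the same way, k(x, x') = exp(-sigma * ||x - x'||^2), so rbf_sigma here plays the same role as sklearn's gamma. A sketch checking the formula against sklearn's rbf_kernel:

```python
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
sigma = 0.001  # the role played by rbf_sigma (kernlab) / gamma (sklearn)

# exp(-sigma * ||x_i - x_j||^2), computed by hand via broadcasting
sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
K = np.exp(-sigma * sq_dists)

print(np.allclose(K, rbf_kernel(X, gamma=sigma)))  # True
```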

Here we tune the cost C and the RBF kernel coefficient gamma.

# Tune hyper-parameter(s)
C_grid = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
gamma_grid = np.logspace(start = -8, stop = -2, base = 10, num = 10)
tuned_parameters = {
  "model__C": C_grid,
  "model__gamma": gamma_grid
  }
tuned_parameters
{'model__C': [0.5, 1.0, 1.5, 2.0, 2.5, 3.0], 'model__gamma': array([1.00000000e-08, 4.64158883e-08, 2.15443469e-07, 1.00000000e-06,
       4.64158883e-06, 2.15443469e-05, 1.00000000e-04, 4.64158883e-04,
       2.15443469e-03, 1.00000000e-02])}

8 Cross-validation (CV)

Set cross-validation partitions.

set.seed(203)

folds <- vfold_cv(Heart_other, v = 5)
folds
#  5-fold cross-validation 
# A tibble: 5 × 2
  splits           id   
  <list>           <chr>
1 <split [181/46]> Fold1
2 <split [181/46]> Fold2
3 <split [182/45]> Fold3
4 <split [182/45]> Fold4
5 <split [182/45]> Fold5

Fit cross-validation.

svm_fit <- svm_wf %>%
  tune_grid(
    resamples = folds,
    grid = param_grid,
    metrics = metric_set(roc_auc, accuracy)
    )
(kernlab warnings during tuning: "maximum number of iterations reached", repeated for many cost/rbf_sigma combinations; the long warning stream is omitted here.)
svm_fit
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 4
  splits           id    .metrics           .notes          
  <list>           <chr> <list>             <list>          
1 <split [181/46]> Fold1 <tibble [140 × 6]> <tibble [0 × 4]>
2 <split [181/46]> Fold2 <tibble [140 × 6]> <tibble [0 × 4]>
3 <split [182/45]> Fold3 <tibble [140 × 6]> <tibble [0 × 4]>
4 <split [182/45]> Fold4 <tibble [140 × 6]> <tibble [0 × 4]>
5 <split [182/45]> Fold5 <tibble [140 × 6]> <tibble [0 × 4]>

Visualize CV results:

svm_fit %>%
  collect_metrics() %>%
  print(width = Inf) %>%
  filter(.metric == "roc_auc") %>%
  ggplot(mapping = aes(x = cost, y = mean, alpha = rbf_sigma)) +
  geom_point() +
  geom_line(aes(group = rbf_sigma)) +
  labs(x = "Cost", y = "CV AUC") +
  scale_x_log10() 
# A tibble: 140 × 8
      cost rbf_sigma .metric  .estimator  mean     n std_err .config         
     <dbl>     <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>           
 1 0.00391 0.00001   accuracy binary     0.542     5  0.0242 pre0_mod01_post0
 2 0.00391 0.00001   roc_auc  binary     0.931     5  0.0155 pre0_mod01_post0
 3 0.00391 0.0000316 accuracy binary     0.542     5  0.0242 pre0_mod02_post0
 4 0.00391 0.0000316 roc_auc  binary     0.931     5  0.0155 pre0_mod02_post0
 5 0.00391 0.0001    accuracy binary     0.542     5  0.0242 pre0_mod03_post0
 6 0.00391 0.0001    roc_auc  binary     0.931     5  0.0155 pre0_mod03_post0
 7 0.00391 0.000316  accuracy binary     0.542     5  0.0242 pre0_mod04_post0
 8 0.00391 0.000316  roc_auc  binary     0.931     5  0.0155 pre0_mod04_post0
 9 0.00391 0.001     accuracy binary     0.542     5  0.0242 pre0_mod05_post0
10 0.00391 0.001     roc_auc  binary     0.934     5  0.0151 pre0_mod05_post0
# ℹ 130 more rows

Show the top 5 models.

svm_fit %>%
  show_best(metric = "roc_auc")
# A tibble: 5 × 8
    cost rbf_sigma .metric .estimator  mean     n std_err .config         
   <dbl>     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>           
1 0.0625     0.001 roc_auc binary     0.936     5  0.0153 pre0_mod25_post0
2 0.125      0.001 roc_auc binary     0.936     5  0.0153 pre0_mod30_post0
3 0.25       0.001 roc_auc binary     0.936     5  0.0153 pre0_mod35_post0
4 0.5        0.001 roc_auc binary     0.936     5  0.0153 pre0_mod40_post0
5 0.0312     0.001 roc_auc binary     0.935     5  0.0151 pre0_mod20_post0

Let’s select the best model.

best_svm <- svm_fit %>%
  select_best(metric = "roc_auc")
best_svm
# A tibble: 1 × 3
    cost rbf_sigma .config         
   <dbl>     <dbl> <chr>           
1 0.0625     0.001 pre0_mod25_post0

Set up CV partitions and CV criterion.

from sklearn.model_selection import GridSearchCV

# Set up CV
n_folds = 5
search = GridSearchCV(
  pipe,
  tuned_parameters,
  cv = n_folds, 
  scoring = "roc_auc",
  # Refit the best model on the whole data set
  refit = True
  )

Fit CV. This is typically the most time-consuming step.

# Fit CV
search.fit(X_other, y_other)
GridSearchCV(cv=5,
             estimator=Pipeline(steps=[('col_tf',
                                        ColumnTransformer(transformers=[('num',
                                                                         Pipeline(steps=[('num_impute',
                                                                                          SimpleImputer())]),
                                                                         ['Age',
                                                                          'Sex',
                                                                          'RestBP',
                                                                          'Chol',
                                                                          'Fbs',
                                                                          'RestECG',
                                                                          'MaxHR',
                                                                          'ExAng',
                                                                          'Oldpeak',
                                                                          'Slope',
                                                                          'Ca']),
                                                                        ('cat',
                                                                         Pipeline(steps=[('cat_impute',
                                                                                          SimpleImputer(strategy='most_frequent')),
                                                                                         ('encoder',
                                                                                          OneHotEncoder(drop='first'))]),
                                                                         ['ChestPain',
                                                                          'Thal'])])),
                                       ('model',
                                        SVC(probability=True,
                                            random_state=425))]),
             param_grid={'model__C': [0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
                         'model__gamma': array([1.00000000e-08, 4.64158883e-08, 2.15443469e-07, 1.00000000e-06,
       4.64158883e-06, 2.15443469e-05, 1.00000000e-04, 4.64158883e-04,
       2.15443469e-03, 1.00000000e-02])},
             scoring='roc_auc')

Visualize CV results.

cv_res = pd.DataFrame({
  "C": np.array(search.cv_results_["param_model__C"]),
  "auc": search.cv_results_["mean_test_score"],
  "gamma": search.cv_results_["param_model__gamma"]
  })

# relplot is figure-level and creates its own figure, so no plt.figure() is needed
sns.relplot(
  # kind = "line",
  data = cv_res,
  x = "gamma",
  y = "auc",
  hue = "C"
  ).set(
    xscale = "log",
    xlabel = "gamma",
    ylabel = "CV AUC"
    );
plt.show()

Best CV AUC:

search.best_score_
np.float64(0.7295460317460316)

The training accuracy is

from sklearn.metrics import accuracy_score, roc_auc_score

accuracy_score(
  y_other,
  search.best_estimator_.predict(X_other)
  )
0.7533039647577092

9 Finalize our model

Now we are done tuning. Finally, let's fit the final model to the whole non-test data and use the test data to estimate the model performance we expect to see with new data.

# Final workflow
final_wf <- svm_wf %>%
  finalize_workflow(best_svm)
final_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_rbf()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_impute_mean()
• step_impute_mode()
• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
Radial Basis Function Support Vector Machine Model Specification (classification)

Main Arguments:
  cost = 0.0625
  rbf_sigma = 0.001

Computational engine: kernlab 
# Fit the whole training set, then predict the test cases
final_fit <- 
  final_wf %>%
  last_fit(data_split)
final_fit
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits           id               .metrics .notes   .predictions .workflow 
  <list>           <chr>            <list>   <list>   <list>       <list>    
1 <split [227/76]> train/test split <tibble> <tibble> <tibble>     <workflow>
# Test metrics
final_fit %>% 
  collect_metrics()
# A tibble: 3 × 4
  .metric     .estimator .estimate .config        
  <chr>       <chr>          <dbl> <chr>          
1 accuracy    binary         0.539 pre0_mod0_post0
2 roc_auc     binary         0.843 pre0_mod0_post0
3 brier_class binary         0.198 pre0_mod0_post0
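The brier_class metric above is the Brier score: the mean squared difference between the predicted probability of the event and the 0/1 outcome (smaller is better). A toy computation, matched against sklearn's implementation:

```python
import numpy as np
from sklearn.metrics import brier_score_loss

y_true = np.array([0, 1, 1, 0])
p_hat = np.array([0.2, 0.9, 0.6, 0.3])  # predicted P(event)

manual = np.mean((p_hat - y_true) ** 2)
print(round(float(manual), 3))  # 0.075
print(np.isclose(manual, brier_score_loss(y_true, p_hat)))  # True
```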

Since we called GridSearchCV with refit = True, the best model fit on the whole non-test data is readily available.

search.best_estimator_
Pipeline(steps=[('col_tf',
                 ColumnTransformer(transformers=[('num',
                                                  Pipeline(steps=[('num_impute',
                                                                   SimpleImputer())]),
                                                  ['Age', 'Sex', 'RestBP',
                                                   'Chol', 'Fbs', 'RestECG',
                                                   'MaxHR', 'ExAng', 'Oldpeak',
                                                   'Slope', 'Ca']),
                                                 ('cat',
                                                  Pipeline(steps=[('cat_impute',
                                                                   SimpleImputer(strategy='most_frequent')),
                                                                  ('encoder',
                                                                   OneHotEncoder(drop='first'))]),
                                                  ['ChestPain', 'Thal'])])),
                ('model',
                 SVC(C=3.0, gamma=np.float64(0.00046415888336127724),
                     probability=True, random_state=425))])

The final AUC on the test set is

roc_auc_score(
  y_test,
  search.best_estimator_.predict_proba(X_test)[:, 1]
  )
np.float64(0.689198606271777)
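AUC has a useful reading: it is the probability that a randomly chosen positive case receives a higher predicted probability than a randomly chosen negative case (ties counted one half). A toy check of that equivalence:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

y = np.array([0, 0, 1, 1, 0, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7])

pos, neg = scores[y == 1], scores[y == 0]
# Fraction of (positive, negative) pairs ranked correctly, ties counted 1/2
auc_pairs = ((pos[:, None] > neg[None, :]).mean()
             + 0.5 * (pos[:, None] == neg[None, :]).mean())
print(np.isclose(auc_pairs, roc_auc_score(y, scores)))  # True
```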

The final classification accuracy on the test set is

accuracy_score(
  y_test,
  search.best_estimator_.predict(X_test)
  )
0.6052631578947368