Caret package cross-validation summary in R

75 views Asked by At

Assume i have a K-folds list with K=10, each element contains caret classification performance output:

dput(transformed_conf_matrices$Fold01)
structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 
0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1), dim = c(11L, 
3L), dimnames = list(c("Sensitivity", "Specificity", "Pos Pred Value", 
"Neg Pred Value", "Precision", "Recall", "F1", "Prevalence", 
"Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
)))



transformed_conf_matrices$Fold01
                     Class: setosa Class: versicolor Class: virginica
Sensitivity              1.0000000         1.0000000        1.0000000
Specificity              1.0000000         1.0000000        1.0000000
Pos Pred Value           1.0000000         1.0000000        1.0000000
Neg Pred Value           1.0000000         1.0000000        1.0000000
Precision                1.0000000         1.0000000        1.0000000
Recall                   1.0000000         1.0000000        1.0000000
F1                       1.0000000         1.0000000        1.0000000
Prevalence               0.3333333         0.3333333        0.3333333
Detection Rate           0.3333333         0.3333333        0.3333333
Detection Prevalence     0.3333333         0.3333333        0.3333333
Balanced Accuracy        1.0000000         1.0000000        1.0000000

In this special case , transformed_conf_matrices$Fold01 to transformed_conf_matrices$Fold10 are equal ( same values ).

I would like to have the mean and variance of those metrics. I did many attempts with lapply without success.

The K-folds list :

dput(transformed_conf_matrices)
list(Fold01 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold02 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold03 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold04 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold05 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold06 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold07 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold08 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold09 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold10 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))))
2

There are 2 answers

0
Friede On BEST ANSWER

I assume the data to be small. Doing things twice doesn't really matter then:

classes = gsub("Class: ", "", colnames(transformed_conf_matrices[[1L]]))
transformed_conf_matrices = do.call(cbind, transformed_conf_matrices)
rowVars = \(x, ...) { rowSums((x - rowMeans(x, ...)) ^ 2L, ...) / (nrow(x) - 1L) }

Then we vapply:

> vapply(classes, \(i) 
+        rowMeans(transformed_conf_matrices[, grepl(i, colnames(transformed_conf_matrices))]), 
+        numeric(length = nrow(transformed_conf_matrices)))
                        setosa versicolor virginica
Sensitivity          1.0000000  1.0000000 1.0000000
Specificity          1.0000000  1.0000000 1.0000000
Pos Pred Value       1.0000000  1.0000000 1.0000000
Neg Pred Value       1.0000000  1.0000000 1.0000000
Precision            1.0000000  1.0000000 1.0000000
Recall               1.0000000  1.0000000 1.0000000
F1                   1.0000000  1.0000000 1.0000000
Prevalence           0.3333333  0.3333333 0.3333333
Detection Rate       0.3333333  0.3333333 0.3333333
Detection Prevalence 0.3333333  0.3333333 0.3333333
Balanced Accuracy    1.0000000  1.0000000 1.0000000
> vapply(classes, \(i) 
+        rowVars(transformed_conf_matrices[, grepl(i, colnames(transformed_conf_matrices))]), 
+        numeric(length = nrow(transformed_conf_matrices)))
                     setosa versicolor virginica
Sensitivity               0          0         0
Specificity               0          0         0
Pos Pred Value            0          0         0
Neg Pred Value            0          0         0
Precision                 0          0         0
Recall                    0          0         0
F1                        0          0         0
Prevalence                0          0         0
Detection Rate            0          0         0
Detection Prevalence      0          0         0
Balanced Accuracy         0          0         0

rowVars() from here.

0
Tou Mou On
# Assuming your data is stored in a list called 'transformed_conf_matrices'
library(dplyr)

# Extract names of metrics and classes
metric_names <- rownames(transformed_conf_matrices[[1]])
class_names <- colnames(transformed_conf_matrices[[1]])

# Initialize an empty data frame to store the results
results <- data.frame(Metric = character(), Class = character(), Mean = numeric(), Variance = numeric())

# Loop over each metric and class
for (metric in metric_names) {
  for (class in class_names) {
    # Extract values for the current metric and class across all folds
    values <- sapply(transformed_conf_matrices, function(x) x[metric, class])

    # Calculate mean and variance
    mean_val <- mean(values)
    var_val <- var(values)

    # Append to the results data frame
    results <- rbind(results, data.frame(Metric = metric, Class = class, Mean = mean_val, Variance = var_val))
  }
}

# Display the final table
print(results)

Output :

> # Display the final table
> print(results)
                 Metric             Class      Mean Variance
1           Sensitivity     Class: setosa 1.0000000        0
2           Sensitivity Class: versicolor 1.0000000        0
3           Sensitivity  Class: virginica 1.0000000        0
4           Specificity     Class: setosa 1.0000000        0
5           Specificity Class: versicolor 1.0000000        0
6           Specificity  Class: virginica 1.0000000        0
7        Pos Pred Value     Class: setosa 1.0000000        0
8        Pos Pred Value Class: versicolor 1.0000000        0
9        Pos Pred Value  Class: virginica 1.0000000        0
10       Neg Pred Value     Class: setosa 1.0000000        0
11       Neg Pred Value Class: versicolor 1.0000000        0
12       Neg Pred Value  Class: virginica 1.0000000        0
13            Precision     Class: setosa 1.0000000        0
14            Precision Class: versicolor 1.0000000        0
15            Precision  Class: virginica 1.0000000        0
16               Recall     Class: setosa 1.0000000        0
17               Recall Class: versicolor 1.0000000        0
18               Recall  Class: virginica 1.0000000        0
19                   F1     Class: setosa 1.0000000        0
20                   F1 Class: versicolor 1.0000000        0
21                   F1  Class: virginica 1.0000000        0
22           Prevalence     Class: setosa 0.3333333        0
23           Prevalence Class: versicolor 0.3333333        0
24           Prevalence  Class: virginica 0.3333333        0
25       Detection Rate     Class: setosa 0.3333333        0
26       Detection Rate Class: versicolor 0.3333333        0
27       Detection Rate  Class: virginica 0.3333333        0
28 Detection Prevalence     Class: setosa 0.3333333        0
29 Detection Prevalence Class: versicolor 0.3333333        0
30 Detection Prevalence  Class: virginica 0.3333333        0
31    Balanced Accuracy     Class: setosa 1.0000000        0
32    Balanced Accuracy Class: versicolor 1.0000000        0
33    Balanced Accuracy  Class: virginica 1.0000000        0