The XAItest package includes several classic feature importance algorithms and supports the addition of new ones. The example below shows how to integrate an XGBoost model and compute its feature importance metrics with the SHAP implementation provided by the shapr package.
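The function should accept: a data frame `df` containing the feature columns and the outcome column, the name `y` of the outcome column, and optional further arguments (`...`).
The function should return: a list with `featImps` (the importance score of each feature), `model` (the fitted model) and `modelPredictions` (the predicted class of each sample).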
# Load the libraries
library(XAItest)
library(ggplot2)
library(ggforce)
library(SummarizedExperiment)
# Load the example SummarizedExperiment shipped with XAItest and flatten it
# into a data.frame with one row per sample: the feature values followed by
# the outcome column "y" taken from the sample metadata (colData).
se_path <- system.file("extdata", "seClassif.rds", package = "XAItest")
dataset_classif <- readRDS(se_path)
# Transpose the "counts" assay so that samples are on rows, matching the
# per-sample rows of colData() that provide "y".
data_matrix <- t(assay(dataset_classif, "counts"))
metadata <- as.data.frame(colData(dataset_classif))
df_simu_classif <- as.data.frame(cbind(data_matrix, y = metadata[["y"]]))
# cbind() can coerce all columns to a common type, so convert every column
# except the outcome back to numeric in one vectorized step.
featureCols <- setdiff(names(df_simu_classif), "y")
df_simu_classif[featureCols] <- lapply(df_simu_classif[featureCols],
                                       as.numeric)
# Custom feature-importance function for XAI.test(): fits an XGBoost model on
# a two-class outcome and scores each feature by its mean absolute Shapley
# value, computed with the shapr package.
#
# Arguments:
#   df  - data.frame holding the feature columns plus the outcome column.
#   y   - name of the outcome column in `df` (default "y"). It must contain
#         exactly two distinct classes; the first class seen is encoded as 0
#         and the second as 1.
#   ... - accepted but unused (presumably so XAI.test() can pass extra
#         arguments - confirm against XAI.test()).
# Returns a list with:
#   featImps         - named numeric vector, mean |Shapley value| per feature.
#   model            - the fitted xgboost model.
#   modelPredictions - in-sample predicted class labels (character), using a
#                      0.5 threshold on the model output.
featureImportanceXGBoost <- function(df, y = "y", ...) {
  # Prepare data; drop = FALSE keeps a matrix even with a single feature.
  matX <- as.matrix(df[, colnames(df) != y, drop = FALSE])
  labels <- as.character(df[[y]])
  # Determine the class order ONCE. Recomputing unique() after relabelling
  # the first class to "0" mis-encodes labels that are themselves "0"/"1".
  classes <- unique(labels)
  if (length(classes) != 2L) {
    stop("featureImportanceXGBoost() needs exactly two classes in '", y,
         "', found ", length(classes), call. = FALSE)
  }
  vecY <- as.numeric(labels == classes[2])

  # Train the XGBoost model.
  # NOTE(review): no objective is given, so xgboost's default objective is
  # used; "binary:logistic" may suit 0/1 labels better - confirm.
  model <- xgboost::xgboost(data = matX, label = vecY,
                            nrounds = 10, verbose = FALSE)
  modelPredictions <- predict(model, matX)
  # Map scores back to the original class labels with a 0.5 threshold.
  modelPredictionsCat <- ifelse(modelPredictions < 0.5,
                                classes[1], classes[2])

  # phi_0, i.e. the expected prediction without any features.
  p <- mean(vecY)
  # Shapley values via kernelSHAP, accounting for feature dependence with the
  # empirical (conditional) distribution approach (bandwidth sigma = 0.1 by
  # default).
  explainer <- shapr::shapr(matX, model, n_combinations = 200)
  # NOTE(review): n_combinations is an argument of shapr::shapr(); passed to
  # explain() it may be absorbed by `...` and ignored - confirm for the
  # installed shapr version.
  explanation <- shapr::explain(
    matX,
    approach = "empirical",
    explainer = explainer,
    prediction_zero = p,
    n_combinations = 1000
  )
  # explanation$dt also carries a "none" column (the phi_0 baseline), which
  # is not a feature; exclude it from the importance scores.
  shapValues <- as.data.frame(explanation$dt)
  shapValues <- shapValues[setdiff(names(shapValues), "none")]
  featImps <- colMeans(abs(shapValues), na.rm = TRUE)

  list(featImps = featImps, model = model,
       modelPredictions = modelPredictionsCat)
}
# Run the XAItest pipeline with the custom XGBoost/SHAP importance function
# registered under the name "XGB_SHAP_feat_imp". The seed is fixed first so
# R's random number stream is reproducible across runs.
# (The original call ended with a stray trailing comma, which passed an empty
# argument to XAI.test(); it has been removed.)
set.seed(123)
results <- XAI.test(dataset_classif, "y",
                    simData = TRUE,
                    simPvalTarget = 0.0005,
                    customFeatImps =
                      list("XGB_SHAP_feat_imp" = featureImportanceXGBoost))
## The specified model provides feature classes that are NA. The classes of data are taken as the truth.
The mapPvalImportance function shows that both the custom XGB_SHAP_feat_imp metric and the other feature importance metrics identify the biDistrib feature as significant.
mapPvalImportance(
  results,
  refPvalColumn = "ttest_adjPval",
  refPval = 0.0005
)
# Plot the model fitted by the custom XGB_SHAP_feat_imp importance function
plotModel(
  results,
  "XGB_SHAP_feat_imp",
  "diff_distrib01",
  "biDistrib"
)
sessionInfo()
## R Under development (unstable) (2024-10-21 r87258)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.1 LTS
##
## Matrix products: default
## BLAS: /home/biocbuild/bbs-3.21-bioc/R/lib/libRblas.so
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.12.0
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_GB LC_COLLATE=C
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## time zone: America/New_York
## tzcode source: system (glibc)
##
## attached base packages:
## [1] stats4 stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] caret_6.0-94 lattice_0.22-6
## [3] SummarizedExperiment_1.37.0 Biobase_2.67.0
## [5] GenomicRanges_1.59.1 GenomeInfoDb_1.43.2
## [7] IRanges_2.41.2 S4Vectors_0.45.2
## [9] BiocGenerics_0.53.3 generics_0.1.3
## [11] MatrixGenerics_1.19.0 matrixStats_1.4.1
## [13] gridExtra_2.3 ggforce_0.4.2
## [15] ggplot2_3.5.1 XAItest_0.99.8
##
## loaded via a namespace (and not attached):
## [1] pROC_1.18.5 rlang_1.1.4 magrittr_2.0.3
## [4] e1071_1.7-16 compiler_4.5.0 lime_0.5.3
## [7] vctrs_0.6.5 reshape2_1.4.4 stringr_1.5.1
## [10] shape_1.4.6.1 pkgconfig_2.0.3 crayon_1.5.3
## [13] XVector_0.47.0 labeling_0.4.3 utf8_1.2.4
## [16] markdown_1.13 prodlim_2024.06.25 UCSC.utils_1.3.0
## [19] purrr_1.0.2 xfun_0.49 glmnet_4.1-8
## [22] randomForest_4.7-1.2 zlibbioc_1.53.0 shapr_0.2.2
## [25] jsonlite_1.8.9 recipes_1.1.0 DelayedArray_0.33.3
## [28] tweenr_2.0.3 parallel_4.5.0 R6_2.5.1
## [31] stringi_1.8.4 limma_3.63.2 parallelly_1.40.1
## [34] rpart_4.1.23 xgboost_1.7.8.1 lubridate_1.9.4
## [37] Rcpp_1.0.13-1 assertthat_0.2.1 iterators_1.0.14
## [40] knitr_1.49 future.apply_1.11.3 Matrix_1.7-1
## [43] splines_4.5.0 nnet_7.3-19 timechange_0.3.0
## [46] tidyselect_1.2.1 yaml_2.3.10 abind_1.4-8
## [49] timeDate_4041.110 codetools_0.2-20 listenv_0.9.1
## [52] tibble_3.2.1 plyr_1.8.9 withr_3.0.2
## [55] evaluate_1.0.1 future_1.34.0 survival_3.7-0
## [58] proxy_0.4-27 polyclip_1.10-7 pillar_1.9.0
## [61] kernelshap_0.7.0 foreach_1.5.2 commonmark_1.9.2
## [64] munsell_0.5.1 scales_1.3.0 globals_0.16.3
## [67] class_7.3-22 glue_1.8.0 tools_4.5.0
## [70] data.table_1.16.4 ModelMetrics_1.2.2.2 gower_1.0.1
## [73] grid_4.5.0 ipred_0.9-15 colorspace_2.1-1
## [76] nlme_3.1-166 GenomeInfoDbData_1.2.13 cli_3.6.3
## [79] fansi_1.0.6 S4Arrays_1.7.1 lava_1.8.0
## [82] dplyr_1.1.4 gtable_0.3.6 digest_0.6.37
## [85] SparseArray_1.7.2 farver_2.1.2 lifecycle_1.0.4
## [88] hardhat_1.4.0 httr_1.4.7 mime_0.12
## [91] statmod_1.5.0 MASS_7.3-61