XAItest: Enhancing Feature Discovery with eXplainable AI

Ghislain FIEVET ghislain.fievet@gmail.com

Add a custom feature importance function

The XAItest package includes several classic feature importance algorithms and supports the addition of new ones. This vignette shows how to integrate an XGBoost model and compute its feature importance metrics with the SHAP package shapr.

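A custom feature importance function must follow the structure below; the example function later in this vignette implements it.

The function should accept:

- `df`: a data frame containing the feature columns and the target column;
- `y`: the name of the target column (default `"y"`);
- `...`: additional arguments.

The function should return a list with the following elements:

- `featImps`: a named numeric vector of feature importances;
- `model`: the trained model;
- `modelPredictions`: the predicted class for each row of `df`.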

Load libraries and classification dataset

# Load the libraries
library(XAItest)
library(ggplot2)
library(ggforce)
library(SummarizedExperiment)
se_path <- system.file("extdata", "seClassif.rds", package = "XAItest")
dataset_classif <- readRDS(se_path)

# A SummarizedExperiment stores features as rows and samples as columns;
# transpose so that each row is a sample and each column a feature.
data_matrix <- t(assay(dataset_classif, "counts"))
metadata <- as.data.frame(colData(dataset_classif))

# Build the data frame without cbind()-ing the matrix with the label column:
# cbind() would coerce every value to character (keeping only 15 significant
# digits) and require a conversion loop back to numeric afterwards.
df_simu_classif <- as.data.frame(data_matrix)
df_simu_classif[] <- lapply(df_simu_classif, as.numeric)
df_simu_classif[["y"]] <- as.character(metadata[["y"]])

Build and use the custom feature importance function

# Custom feature importance function for XAI.test(): trains an XGBoost model
# on a two-class data frame and scores each feature by its mean absolute
# Shapley value, computed with the shapr package.
#
# Arguments:
#   df  - data frame holding the feature columns plus the target column.
#   y   - name of the target column (default "y"); it must contain exactly
#         two distinct classes.
#   ... - accepted for compatibility with the custom-function interface;
#         not used.
#
# Returns a list with:
#   featImps         - named numeric vector, one mean |SHAP| value per feature.
#   model            - the trained xgboost model.
#   modelPredictions - predicted class label for each row of df.
featureImportanceXGBoost <- function(df, y = "y", ...) {
    # Prepare data; drop = FALSE keeps a (named) matrix even when there is a
    # single feature column.
    matX <- as.matrix(df[, colnames(df) != y, drop = FALSE])
    labels <- as.character(df[[y]])
    classes <- unique(labels)
    if (length(classes) != 2) {
        stop("featureImportanceXGBoost() requires a target with exactly two ",
             "classes; found ", length(classes), ".", call. = FALSE)
    }
    # Encode the first observed class as 0 and the second as 1 in a single
    # comparison; replacing labels in place breaks when a label is "0"/"1".
    vecY <- as.numeric(labels == classes[2])

    for (pkg in c("xgboost", "shapr")) {
        if (!requireNamespace(pkg, quietly = TRUE)) {
            stop("Package '", pkg, "' is required.", call. = FALSE)
        }
    }

    # Train the XGBoost model
    model <- xgboost::xgboost(data = matX, label = vecY,
                              nrounds = 10, verbose = FALSE)
    modelPredictions <- predict(model, matX)
    # Map numeric predictions back to the original class labels (cut-off 0.5)
    modelPredictionsCat <- ifelse(modelPredictions < 0.5,
                                  classes[1], classes[2])

    # Specifying the phi_0, i.e. the expected prediction without any features
    p <- mean(vecY)
    # Computing the actual Shapley values with kernelSHAP accounting
    # for feature dependence using the empirical (conditional)
    # distribution approach with bandwidth parameter sigma = 0.1 (default)
    explainer <- shapr::shapr(matX, model, n_combinations = 200)
    # NOTE(review): shapr() above already fixes the number of coalitions;
    # confirm whether explain() actually uses n_combinations here.
    explanation <- shapr::explain(
        matX,
        approach = "empirical",
        explainer = explainer,
        prediction_zero = p,
        n_combinations = 1000
    )
    shapValues <- as.data.frame(explanation$dt)
    # The "none" column holds phi_0 (the baseline prediction), not a feature
    # contribution, so it is excluded from the importances.
    shapValues <- shapValues[, setdiff(names(shapValues), "none"),
                             drop = FALSE]
    results <- colMeans(abs(shapValues), na.rm = TRUE)

    list(featImps = results, model = model,
         modelPredictions = modelPredictionsCat)
}
# Run XAI.test() with the custom importance function registered under the
# name "XGB_SHAP_feat_imp"; set.seed() makes the run reproducible.
set.seed(123)
results <- XAI.test(dataset_classif, "y", simData = TRUE,
                    simPvalTarget = 0.0005,
                    customFeatImps =
                        list("XGB_SHAP_feat_imp" = featureImportanceXGBoost))
## The specified model provides feature classes that are NA. The classes of data are taken as the truth.

The mapPvalImportance function reveals that both the custom XGB_SHAP_feat_imp and other feature importance metrics identify the biDistrib feature as significant.

# Compare p-values with feature importances, using the "ttest_adjPval"
# column as the reference p-value and 0.0005 as the threshold.
XAItest::mapPvalImportance(
    results,
    refPvalColumn = "ttest_adjPval",
    refPval = 0.0005
)

# Plot of the XGBoost generated model
XAItest::plotModel(
    results,
    "XGB_SHAP_feat_imp",
    "diff_distrib01",
    "biDistrib"
)

plot of chunk unnamed-chunk-2

# Print the R version, platform and package versions used for this document.
utils::sessionInfo()
## R Under development (unstable) (2024-10-21 r87258)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.1 LTS
## 
## Matrix products: default
## BLAS:   /home/biocbuild/bbs-3.21-bioc/R/lib/libRblas.so 
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.12.0
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_GB              LC_COLLATE=C              
##  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
## 
## time zone: America/New_York
## tzcode source: system (glibc)
## 
## attached base packages:
## [1] stats4    stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
##  [1] caret_6.0-94                lattice_0.22-6             
##  [3] SummarizedExperiment_1.37.0 Biobase_2.67.0             
##  [5] GenomicRanges_1.59.1        GenomeInfoDb_1.43.2        
##  [7] IRanges_2.41.2              S4Vectors_0.45.2           
##  [9] BiocGenerics_0.53.3         generics_0.1.3             
## [11] MatrixGenerics_1.19.0       matrixStats_1.4.1          
## [13] gridExtra_2.3               ggforce_0.4.2              
## [15] ggplot2_3.5.1               XAItest_0.99.8             
## 
## loaded via a namespace (and not attached):
##  [1] pROC_1.18.5             rlang_1.1.4             magrittr_2.0.3         
##  [4] e1071_1.7-16            compiler_4.5.0          lime_0.5.3             
##  [7] vctrs_0.6.5             reshape2_1.4.4          stringr_1.5.1          
## [10] shape_1.4.6.1           pkgconfig_2.0.3         crayon_1.5.3           
## [13] XVector_0.47.0          labeling_0.4.3          utf8_1.2.4             
## [16] markdown_1.13           prodlim_2024.06.25      UCSC.utils_1.3.0       
## [19] purrr_1.0.2             xfun_0.49               glmnet_4.1-8           
## [22] randomForest_4.7-1.2    zlibbioc_1.53.0         shapr_0.2.2            
## [25] jsonlite_1.8.9          recipes_1.1.0           DelayedArray_0.33.3    
## [28] tweenr_2.0.3            parallel_4.5.0          R6_2.5.1               
## [31] stringi_1.8.4           limma_3.63.2            parallelly_1.40.1      
## [34] rpart_4.1.23            xgboost_1.7.8.1         lubridate_1.9.4        
## [37] Rcpp_1.0.13-1           assertthat_0.2.1        iterators_1.0.14       
## [40] knitr_1.49              future.apply_1.11.3     Matrix_1.7-1           
## [43] splines_4.5.0           nnet_7.3-19             timechange_0.3.0       
## [46] tidyselect_1.2.1        yaml_2.3.10             abind_1.4-8            
## [49] timeDate_4041.110       codetools_0.2-20        listenv_0.9.1          
## [52] tibble_3.2.1            plyr_1.8.9              withr_3.0.2            
## [55] evaluate_1.0.1          future_1.34.0           survival_3.7-0         
## [58] proxy_0.4-27            polyclip_1.10-7         pillar_1.9.0           
## [61] kernelshap_0.7.0        foreach_1.5.2           commonmark_1.9.2       
## [64] munsell_0.5.1           scales_1.3.0            globals_0.16.3         
## [67] class_7.3-22            glue_1.8.0              tools_4.5.0            
## [70] data.table_1.16.4       ModelMetrics_1.2.2.2    gower_1.0.1            
## [73] grid_4.5.0              ipred_0.9-15            colorspace_2.1-1       
## [76] nlme_3.1-166            GenomeInfoDbData_1.2.13 cli_3.6.3              
## [79] fansi_1.0.6             S4Arrays_1.7.1          lava_1.8.0             
## [82] dplyr_1.1.4             gtable_0.3.6            digest_0.6.37          
## [85] SparseArray_1.7.2       farver_2.1.2            lifecycle_1.0.4        
## [88] hardhat_1.4.0           httr_1.4.7              mime_0.12              
## [91] statmod_1.5.0           MASS_7.3-61