## ----include = FALSE----------------------------------------------------------
#' Global knitr chunk options applied to every chunk of this document
knitr::opts_chunk$set(
    # Text output: merge source and output, prefix output lines with "#>"
    comment = "#>",
    collapse = TRUE,
    # Figure output: target directory prefix and resolution
    dpi = 36,
    fig.path = "figures/"
)

## ----eval=TRUE, include=FALSE-------------------------------------------------
#' Record when the script started; the elapsed time is reported at the end
start.time <- Sys.time()

## ----eval=TRUE, message=FALSE, warning=FALSE----------------------------------
library(Banksy)
library(SummarizedExperiment)
library(SpatialExperiment)
library(Seurat)
library(scran)
library(data.table)
library(harmony)

library(scater)
library(cowplot)
library(ggplot2)

#' Shared random seed, reused by the seeded steps further down the script
SEED <- 1000

## ----eval=T, message=FALSE, warning=FALSE-------------------------------------
library(spatialLIBD)
library(ExperimentHub)

#' Open an ExperimentHub handle and use it to fetch the spatialLIBD data set of
#' type "spe", which is stored in `spe` for the rest of the script
ehub <- ExperimentHub::ExperimentHub()
spe <- spatialLIBD::fetch_data(type = "spe", eh = ehub)

## ----eval=T, message=FALSE, warning=FALSE-------------------------------------
#' Remove NA spots
#' A logical mask is used instead of negative integer indices: when no spot has
#' a missing annotation, `which()` returns integer(0) and `spe[, -integer(0)]`
#' would select zero columns rather than keep all of them.
has_annotation <- !is.na(spe$layer_guess_reordered_short)
spe <- spe[, has_annotation]

#' Trim
#' Drop the image data, logcounts assay, reduced dimensions and row data, and
#' rebuild colData with only the columns used below.
imgData(spe) <- NULL
assay(spe, "logcounts") <- NULL
reducedDims(spe) <- NULL
rowData(spe) <- NULL
colData(spe) <- DataFrame(
    sample_id = spe$sample_id,
    # Consecutive runs of four sample levels are relabelled Subject1..Subject3
    subject_id = factor(spe$sample_id, labels = rep(paste0("Subject", 1:3), each = 4)),
    # Annotation converted to numeric codes and stored as a factor
    clust_annotation = factor(as.numeric(spe$layer_guess_reordered_short)),
    in_tissue = spe$in_tissue,
    row.names = colnames(spe)
)
#' Suffix the column names with the sample id
colnames(spe) <- paste0(colnames(spe), "_", spe$sample_id)
invisible(gc())

## ----eval=T, message=FALSE, warning=FALSE-------------------------------------
#' Keep the spots of the three selected samples and record which sample ids
#' remain after subsetting
selected_samples <- c("151507", "151669", "151673")
spe <- spe[, is.element(spe$sample_id, selected_samples)]
sample_names <- unique(spe$sample_id)

## ----eval=T, message=FALSE, warning=FALSE-------------------------------------
#' Stagger spatial coordinates
#' Each sample is shifted along x by a multiple of a common offset so that the
#' samples sit side by side in one coordinate frame.
sample_code <- as.integer(factor(spe$sample_id))
locs_dt <- data.table(cbind(spatialCoords(spe), sample_id = sample_code))
setnames(locs_dt, c("sdimx", "sdimy", "group"))

#' Re-origin x at zero within every sample
locs_dt[, sdimx := sdimx - min(sdimx), by = group]

#' Offset is 1.5 times the widest re-origined x value, scaled by the sample code
global_max <- 1.5 * max(locs_dt$sdimx)
locs_dt[, sdimx := sdimx + group * global_max]

#' Write the staggered x/y coordinates back, keyed by spot name
locs <- as.matrix(locs_dt[, c("sdimx", "sdimy")])
rownames(locs) <- colnames(spe)
spatialCoords(spe) <- locs

## ----eval=T, message=FALSE, warning=FALSE-------------------------------------
#' Get HVGs
#' Variable features are selected on the Seurat conversion of `spe` before any
#' normalisation is applied to it.
seu <- FindVariableFeatures(as.Seurat(spe, data = NULL), nfeatures = 2000)

#' Normalize data
#' Relative-counts normalisation, scaled to the median per-spot total count
scale_factor <- median(colSums(assay(spe, "counts")))
seu <- NormalizeData(
    seu,
    normalization.method = "RC",
    scale.factor = scale_factor
)

#' Add data to SpatialExperiment and subset to HVGs
aname <- "normcounts"
assay(spe, aname) <- GetAssayData(seu)
hvg_ids <- VariableFeatures(seu)
spe <- spe[hvg_ids, ]

## ----eval=T, message=FALSE, warning=FALSE-------------------------------------
#' Run computeBanksy on the normalised assay
k_geom <- 18
compute_agf <- TRUE
spe <- computeBanksy(
    spe,
    assay_name = aname,
    k_geom = k_geom,
    compute_agf = compute_agf
)

## ----pca, eval=T, message=FALSE, warning=FALSE--------------------------------
#' BANKSY PCA settings; `lambda`, `npcs` and `use_agf` are read again by later
#' chunks, so they stay as top-level variables
use_agf <- TRUE
lambda <- 0.2
npcs <- 20
spe <- runBanksyPCA(
    spe,
    lambda = lambda,
    use_agf = use_agf,
    npcs = npcs,
    seed = SEED
)

## ----harmony, eval=T, message=FALSE, warning=FALSE----------------------------
#' Batch-correct the BANKSY PCA embedding with Harmony
set.seed(SEED)

#' Build the embedding name from the PCA settings instead of hard-coding it, so
#' it cannot go stale when `lambda` or `use_agf` is changed in the PCA chunk.
#' For use_agf = TRUE and lambda = 0.2 this yields "PCA_M1_lam0.2".
pca_name <- sprintf("PCA_M%d_lam%s", as.integer(use_agf), lambda)
#' Fail early with a readable condition if no such embedding exists
stopifnot(pca_name %in% reducedDimNames(spe))

harmony_embedding <- HarmonyMatrix(
    data_mat = reducedDim(spe, pca_name),
    meta_data = colData(spe),
    vars_use = c("sample_id", "subject_id"),
    # The input is already a PCA embedding, so Harmony must not run its own
    do_pca = FALSE,
    max.iter.harmony = 20,
    verbose = FALSE
)
#' Store the corrected embedding under the name used by the UMAP and
#' clustering steps
reducedDim(spe, "Harmony_BANKSY") <- harmony_embedding

## ----umap, eval=T, message=FALSE, warning=FALSE-------------------------------
#' UMAP of the uncorrected BANKSY embedding. `use_agf` is taken from the PCA
#' settings rather than repeated as a literal, and the run is seeded like the
#' PCA and clustering steps so the layout is reproducible.
spe <- runBanksyUMAP(spe, use_agf = use_agf, lambda = lambda, npcs = npcs, seed = SEED)
#' UMAP of the Harmony-corrected embedding, seeded the same way
spe <- runBanksyUMAP(spe, dimred = "Harmony_BANKSY", seed = SEED)

## ----batch-correction-umap, eval=T, fig.height=3.5, out.width='90%', message=FALSE, warning=FALSE----
#' Scatter plot of one reduced dimension, coloured by subject
umap_by_subject <- function(dimred_name) {
    plotReducedDim(
        spe,
        dimred_name,
        point_size = 0.1,
        point_alpha = 0.5,
        color_by = "subject_id"
    )
}

#' Left panel: uncorrected embedding, legend hidden
umap_uncorrected <- umap_by_subject("UMAP_M1_lam0.2") +
    theme(legend.position = "none")

#' Right panel: Harmony-corrected embedding; it carries the legend, drawn with
#' larger opaque keys
umap_corrected <- umap_by_subject("UMAP_Harmony_BANKSY") +
    theme(legend.title = element_blank()) +
    guides(colour = guide_legend(override.aes = list(size = 5, alpha = 1)))

#' The right panel is given extra width for its legend
plot_grid(
    umap_uncorrected,
    umap_corrected,
    nrow = 1,
    rel_widths = c(1, 1.2)
)

## ----eval=T, message=FALSE, warning=FALSE-------------------------------------
#' Cluster on the Harmony-corrected embedding
spe <- clusterBanksy(
    spe,
    dimred = "Harmony_BANKSY",
    resolution = 0.55,
    seed = SEED
)
#' Map the resulting cluster labels to the manual annotation column
spe <- connectClusters(spe, map_to = "clust_annotation")

## ----batch-correction-spatial, eval=T, fig.height=4, out.width='90%', message=FALSE, warning=FALSE----
#' Second entry returned by clusterNames(); used as the cluster column to plot
cnm <- clusterNames(spe)[2]

#' Build the spatial cluster plot for one sample, with the ARI between the
#' annotation and the chosen cluster column shown in the title
plot_sample_clusters <- function(snm) {
    x <- spe[, spe$sample_id == snm]
    cluster_labels <- colData(x)[[cnm]]
    ari <- aricode::ARI(x$clust_annotation, cluster_labels)

    df <- data.frame(
        clust = cluster_labels,
        spatialCoords(x),
        check.names = FALSE
    )
    panel_title <- sprintf("Sample %s - ARI: %s", snm, round(ari, 3))

    # sdimy is drawn on the x axis and sdimx on the y axis
    ggplot(df, aes(x = sdimy, y = sdimx, col = clust)) +
        geom_point(size = 0.5) +
        scale_color_manual(values = pals::kelly()[-1]) +
        theme_classic() +
        theme(
            legend.position = "none",
            axis.text.x = element_blank(),
            axis.text.y = element_blank(),
            axis.ticks = element_blank(),
            axis.title.x = element_blank(),
            axis.title.y = element_blank()
        ) +
        labs(title = panel_title) +
        coord_equal()
}

spatial_plots <- lapply(sample_names, plot_sample_clusters)

plot_grid(plotlist = spatial_plots, ncol = 3, byrow = FALSE)

## ----eval=TRUE, echo=FALSE----------------------------------------------------
#' Elapsed time since `start.time` was recorded at the top of the script
Sys.time() - start.time

## ----sess---------------------------------------------------------------------
#' Print the R session details for this run
sessionInfo()