
Overview

Pratola et al. (2020) introduced the heteroskedastic BART model, which asserts $Y \sim N(f(\mathbf{X}), \sigma^{2}(\mathbf{X}))$ for some unknown functions $f$ and $\sigma^{2}.$ They approximated the mean function $f$ with a sum of regression trees and the standard deviation function $\sigma$ with a product of regression trees, and placed conditionally conjugate inverse gamma priors on the leaf node parameters in the $\sigma$-ensemble. This allowed them to follow the same two-step regression tree update strategy introduced in Chipman et al. (1998) and used in the original BART paper: first update the tree structure marginally with Metropolis-Hastings and then conditionally draw new leaf node parameters from a conjugate distribution.

flexBART supports fitting heteroskedastic BART models but uses a slightly different prior specification and computational strategy than Pratola et al. (2020). First, flexBART models $\log \sigma(\mathbf{X})$ as a sum of regression trees and specifies independent mean-zero normal priors on the leaf node parameters of those trees. Because this specification is not conditionally conjugate, we cannot use the conventional two-step regression tree update. Instead, flexBART follows Linero (2025) and updates both the tree structure and leaf outputs in a single Metropolis-Hastings step, proposing new leaf node parameters from a Laplace approximation of their conditional posterior distribution given the proposed tree structure.
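As a rough schematic of this specification (where $M$ and $M'$ denote the numbers of trees in the two ensembles and $g(\mathbf{X}; T, \mu)$ denotes the step function defined by a tree $T$ with leaf parameters $\mu$; this notation is our shorthand, not part of flexBART's interface):

$$
Y \sim N\left(f(\mathbf{X}), \sigma^{2}(\mathbf{X})\right), \qquad
f(\mathbf{X}) = \sum_{m=1}^{M} g(\mathbf{X}; T_{m}, \mu_{m}), \qquad
\log \sigma(\mathbf{X}) = \sum_{m=1}^{M'} g(\mathbf{X}; T'_{m}, \mu'_{m}),
$$

with each leaf node parameter in the $\log \sigma$ ensemble given an independent mean-zero normal prior.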

Fitting a heteroskedastic BART model with flexBART is nearly identical to fitting a basic BART model: we still specify a formula argument and provide data via the train_data argument. The key difference is that for heteroskedastic models, the right-hand side of the ~ in the formula argument must include a sigma() term: ~ bart(.) + sigma(.).

Illustration

To illustrate, we first generate data from a heteroskedastic regression model $Y \sim N(4X^{2}, \sigma^{2}(X))$ with $\sigma(X) = e^{2X}/5,$ where $X \in [0,1].$

n_train <- 10000

set.seed(102)
train_data <- data.frame(Y = rep(NA, times = n_train))
train_data[,"X1"] <- runif(n = n_train, min = 0, max = 1)

# True regression function: f(x) = 4x^2
mu_true <- function(df){
  return(4 * df[,"X1"]^2)
}
# True standard deviation function: sigma(x) = exp(2x)/5
sigma_true <- function(df){
  return(exp(2 * df[,"X1"])/5)
}

mu_train <- mu_true(train_data)
sigma_train <- sigma_true(train_data)

train_data$Y <- mu_train + sigma_train * rnorm(n = n_train, mean = 0, sd = 1)

We can now fit the model.

library(flexBART)

fit <- flexBART(formula = Y ~ bart(.) + sigma(.),
                train_data = train_data)

Computing Posterior Summaries

In addition to its yhat.train and yhat.train.mean elements, which respectively contain MCMC draws and posterior means of the $f(\mathbf{x}_{i})$'s, fit contains two additional elements, sigma.train and sigma.train.mean, which respectively contain MCMC draws and posterior means of the $\sigma(\mathbf{x}_{i})$'s.
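As a quick check of these shapes (assuming fit is the object returned above), the draw matrices have one row per saved MCMC sample and one column per training observation, while the .mean vectors have one entry per training observation:

# Posterior draws: rows index saved MCMC samples, columns index training observations
dim(fit$yhat.train)
dim(fit$sigma.train)

# Posterior means: one value per training observation
length(fit$yhat.train.mean)
length(fit$sigma.train.mean)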

oi_colors <- palette.colors(palette = "Okabe-Ito")
par(mar = c(3,3,2,1), mgp = c(1.8, 0.5, 0), mfrow = c(1,2))
f_limits <- range(c(mu_train, fit$yhat.train.mean))
plot(mu_train, fit$yhat.train.mean, 
     pch = 16, cex = 0.5,
     xlim= f_limits, ylim = f_limits,
     xlab = "Actual", ylab = "Posterior Mean",
     main = "Regression Function")
abline(a = 0, b = 1, col = oi_colors[3])

sigma_limits <- range(c(sigma_train, fit$sigma.train.mean))
plot(sigma_train, fit$sigma.train.mean, 
     pch = 16, cex = 0.5,
     xlim= sigma_limits, ylim = sigma_limits,
     xlab = "Actual", ylab = "Posterior Mean",
     main = "Standard Deviation Function")
abline(a = 0, b = 1, col = oi_colors[3])
Figure 1: Actual (horizontal) vs. fitted (vertical) values of the regression function (left) and residual standard deviation function (right) evaluated at the training observations.

The coverage of the pointwise 95% posterior credible intervals for the $f(\mathbf{x}_{i})$'s and $\sigma(\mathbf{x}_{i})$'s is also close to or above the nominal level.

train_mean_summary <-
  data.frame(
    MEAN = apply(fit$yhat.train, MARGIN = 2, FUN = mean),
    L95 = apply(fit$yhat.train, MARGIN = 2, FUN = quantile, probs = 0.025),
    U95 = apply(fit$yhat.train, MARGIN = 2, FUN = quantile, probs = 0.975))

train_sigma_summary <-
  data.frame(
    MEAN = apply(fit$sigma.train, MARGIN = 2, FUN = mean),
    L95 = apply(fit$sigma.train, MARGIN = 2, FUN = quantile, probs = 0.025),
    U95 = apply(fit$sigma.train, MARGIN = 2, FUN = quantile, probs = 0.975))


cat("Coverage for f:", round ( mean( mu_train >= train_mean_summary$L95 & mu_train <= train_mean_summary$U95), digits = 3), "\n")
#> Coverage for f: 0.956
cat("Coverage for sigma:", round ( mean( sigma_train >= train_sigma_summary$L95 & sigma_train <= train_sigma_summary$U95), digits = 3), "\n")
#> Coverage for sigma: 1

Posterior Predictive Simulation

Posterior predictive simulation with heteroskedastic BART proceeds almost exactly like it does with homoskedastic BART. In the code below, we loop over the training observations and, for each MCMC sample of $(f(\mathbf{x}), \sigma(\mathbf{x})),$ we draw an independent $\epsilon^{\star} \sim N(0,1)$ and compute $f(\mathbf{x}) + \sigma(\mathbf{x})\epsilon^{\star}.$

nd <- nrow(fit$yhat.train)
ystar_train <- matrix(nrow = nd, ncol = n_train)
# Drawing from N(f(x_i), sigma(x_i)^2) is equivalent to computing f(x_i) + sigma(x_i) * epsilon with epsilon ~ N(0,1)
for(i in 1:n_train){
  ystar_train[,i] <- fit$yhat.train[,i] + rnorm(n = nd, mean = 0, sd = fit$sigma.train[,i])
}

ystar_quantiles <- 
  apply(ystar_train, MARGIN = 2, 
        FUN = quantile, probs = c(0.025, 0.975)) |>
  t()
mean( train_data$Y >= ystar_quantiles[,"2.5%"] & train_data$Y <= ystar_quantiles[,"97.5%"])
#> [1] 0.9564

Making Predictions

We can use flexBART’s prediction method (predict.flexBART(), accessed via the generic predict()) to make predictions about $f$ and $\sigma$ evaluated at new inputs. In the code below, we create a grid of equally spaced $X_{1}$ values running from 0 to 1. We then use predict() to compute posterior samples of these evaluations, which we compare to the actual values.

# Predicted Draws
test_df <- data.frame(X1 = seq(0, 1, by = 0.01))
test_pred <- predict(object = fit, newdata = test_df)

# Actual values
mu_test <- mu_true(test_df)
sigma_test <- sigma_true(test_df)

The yhat and sigma elements of test_pred respectively contain MCMC samples of the mean and standard deviation functions evaluated at the test inputs. In the code below, we compute the posterior means and 95% credible intervals for both functions on the test data.

test_mean_summary <- 
  data.frame(
    MEAN = apply(test_pred$yhat, MARGIN = 2, FUN = mean),
    L95 = apply(test_pred$yhat, MARGIN = 2, FUN = quantile, probs = 0.025),
    U95 = apply(test_pred$yhat, MARGIN = 2, FUN = quantile, probs = 0.975))
test_sigma_summary <-
  data.frame(
    MEAN = apply(test_pred$sigma, MARGIN = 2, FUN = mean),
    L95 = apply(test_pred$sigma, MARGIN = 2, FUN = quantile, probs = 0.025),
    U95 = apply(test_pred$sigma, MARGIN = 2, FUN = quantile, probs = 0.975))

Since both $f$ and $\sigma$ depend only on $X_{1},$ we can visualize the posterior mean and pointwise credible intervals as functions of $X_{1}$.

f_limits <- range(c(mu_test, test_mean_summary))
sigma_limits <- range(c(sigma_test, test_sigma_summary))

par(mar = c(3,3,2,1), mgp = c(1.8, 0.5, 0), mfrow = c(1,2))

plot(1, type = "n", xlim = c(0,1), ylim = f_limits,
     main = "Mean Function", xlab = "X1", ylab = "f(X)")

polygon(x = c(test_df$X1, rev(test_df$X1)),
        y = c(test_mean_summary$L95, rev(test_mean_summary$U95)),
        col = adjustcolor(oi_colors[3], alpha.f = 0.5), border = NA)
lines(x = test_df$X1, y = test_mean_summary$MEAN, 
      col = oi_colors[3])
lines(test_df$X1, y = mu_test)

plot(1, type = "n", xlim = c(0,1), ylim = sigma_limits,
     main = "Std. Dev. Function", xlab = "X1", ylab = "f(X)")

polygon(x = c(test_df$X1, rev(test_df$X1)),
        y = c(test_sigma_summary$L95, rev(test_sigma_summary$U95)),
        col = adjustcolor(oi_colors[3], alpha.f = 0.5), border = NA)
lines(x = test_df$X1, y = test_sigma_summary$MEAN, 
      col = oi_colors[3])
lines(test_df$X1, y = sigma_test)
Figure 2: True mean (left) and standard deviation (right) functions (black) and the posterior means and pointwise 95% credible intervals (blue).