mc.crisk2.pwbart {BART} | R Documentation |
BART is a Bayesian “sum-of-trees” model.
For a numeric response y, we have
y = f(x) + e,
where e ~ N(0,sigma^2).
f is the sum of many tree models. The goal is to have very flexible inference for the unknown function f.
In the spirit of “ensemble models”, each tree is constrained by a prior to be a weak learner so that it contributes a small amount to the overall fit.
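The sum-of-trees structure described above can be written out explicitly. The symbols m (number of trees), T_j (the j-th tree structure), M_j (its terminal-node values) and g are not defined on this page; they follow the usual BART notation and are introduced here only for illustration.

```latex
y = f(x) + e, \qquad e \sim N(0, \sigma^2), \qquad
f(x) = \sum_{j=1}^{m} g(x;\, T_j, M_j)
```

Each term g(x; T_j, M_j) is the contribution of a single tree, which the prior keeps small relative to the overall fit.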
mc.crisk2.pwbart( x.test, x.test2, treedraws, treedraws2, binaryOffset=0, binaryOffset2=0, mc.cores=2L, type='pbart', transposed=FALSE, nice=19L )
x.test |
Matrix of covariates to predict y for cause 1. |
x.test2 |
Matrix of covariates to predict y for cause 2. |
treedraws |
$treedraws for cause 1, as returned by crisk2.bart or mc.crisk2.bart. |
treedraws2 |
$treedraws2 for cause 2, as returned by crisk2.bart or mc.crisk2.bart. |
binaryOffset |
Mean to add on to y prediction for cause 1. |
binaryOffset2 |
Mean to add on to y prediction for cause 2. |
mc.cores |
Number of threads to utilize. |
type |
Whether to employ Albert-Chib, 'pbart', or Holmes-Held, 'lbart'. |
transposed |
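When running pwbart or mc.pwbart in parallel, it is more memory-efficient to transpose x.test prior to calling the internal versions of these functions. |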
nice |
Set the job niceness. The default niceness is 19: niceness goes from 0 (highest priority) to 19 (lowest priority). |
BART is a Bayesian MCMC method. At each MCMC iteration, we produce a draw from the joint posterior (f,sigma) | (x,y) in the numeric y case and just f in the binary y case.
Thus, unlike a lot of other modelling methods in R, we do not produce a single model object from which fits and summaries may be extracted. The output consists of values f*(x) (and sigma* in the numeric case) where * denotes a particular draw. The x is either a row from the training data (x.train) or the test data (x.test).
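Because the output is a collection of draws rather than a single fitted object, summaries are typically computed across the draws. The following is a minimal sketch in base R, assuming a matrix laid out with one row per kept draw and one column per prediction point. The matrix here is simulated as a stand-in and is not BART output; the names ndpost.demo, npred and draws are hypothetical.

```r
## Simulated stand-in for a matrix of posterior draws (NOT real BART output):
## rows are kept MCMC draws, columns are prediction points.
set.seed(1)
ndpost.demo <- 200  # hypothetical number of kept draws
npred <- 3          # hypothetical number of prediction points
draws <- matrix(rnorm(ndpost.demo * npred,
                      mean = rep(c(-1, 0, 1), each = ndpost.demo)),
                nrow = ndpost.demo, ncol = npred)

## Posterior mean at each prediction point: average over the draws (rows).
post.mean <- colMeans(draws)

## 95% equal-tailed credible interval at each prediction point.
post.ci <- apply(draws, 2, quantile, probs = c(0.025, 0.975))

## One column per prediction point: mean, lower and upper limits.
draw.summary <- rbind(mean = post.mean, post.ci)
print(round(draw.summary, 3))
```

The same column-wise pattern would apply to any of the draw matrices described below, under the assumption that they share this row-per-draw layout.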
Returns an object of type crisk2bart
which is essentially a list with components:
yhat.test |
A matrix with ndpost rows and nrow(x.test) columns.
Each row corresponds to a draw f* from the posterior of f
and each column corresponds to a row of x.test.
The (i,j) value is f*(x) for the i-th kept draw of f
and the j-th row of x.test. |
surv.test |
test data fits for survival probability. |
surv.test.mean |
mean of surv.test columns. |
prob.test |
The probability of suffering cause 1 which is occasionally useful, e.g., in calculating the concordance. |
prob.test2 |
The probability of suffering cause 2 which is occasionally useful, e.g., in calculating the concordance. |
cif.test |
The cumulative incidence function of cause 1, F_1(t, x), where x's are the rows of the test data. |
cif.test2 |
The cumulative incidence function of cause 2, F_2(t, x), where x's are the rows of the test data. |
yhat.test.mean |
test data fits = mean of yhat.test columns. |
cif.test.mean |
mean of cif.test columns for cause 1. |
cif.test2.mean |
mean of cif.test2 columns for cause 2. |
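The *.mean components are plain vectors, so it can help to reshape them into one curve per test subject. The sketch below assumes the expanded test data stack K time points for each subject in turn, so that element (i-1)*K + j belongs to subject i at time j. That layout is an assumption to check against your own fit. The values are simulated and are not BART output; K, N and cif.mat are illustrative names.

```r
## Simulated stand-in for a cif.test.mean vector (NOT real BART output).
K <- 4  # hypothetical number of distinct time points
N <- 2  # hypothetical number of test subjects
set.seed(2)
## Build one nondecreasing curve per subject, stored subject by subject.
increments <- matrix(runif(K * N, min = 0, max = 0.1), nrow = K, ncol = N)
cif.test.mean <- as.vector(apply(increments, 2, cumsum))

## ASSUMED layout: element (i-1)*K + j is subject i at time point j,
## so filling a K-row matrix column by column gives one column per subject.
cif.mat <- matrix(cif.test.mean, nrow = K, ncol = N)

## Row j, column i is the cumulative incidence for subject i at time j.
print(round(cif.mat, 3))
```

Each column of cif.mat could then be plotted against the time grid as a step function.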
Robert McCulloch: robert.e.mcculloch@gmail.com,
Rodney Sparapani: rsparapa@mcw.edu.
Sparapani, R., Logan, B., McCulloch, R., and Laud, P. (2016) Nonparametric survival analysis using Bayesian Additive Regression Trees (BART). Statistics in Medicine, 16:2741-53 <doi:10.1002/sim.6893>.
pwbart, crisk2.bart, mc.crisk2.bart
library(BART)

data(transplant)

delta <- (as.numeric(transplant$event)-1)
## recode so that delta=1 is cause of interest; delta=2 otherwise
delta[delta==1] <- 4
delta[delta==2] <- 1
delta[delta>1] <- 2
table(delta, transplant$event)

times <- pmax(1, ceiling(transplant$futime/7)) ## weeks
##times <- pmax(1, ceiling(transplant$futime/30.5)) ## months
table(times)

typeO <- 1*(transplant$abo=='O')
typeA <- 1*(transplant$abo=='A')
typeB <- 1*(transplant$abo=='B')
typeAB <- 1*(transplant$abo=='AB')
table(typeA, typeO)

x.train <- cbind(typeO, typeA, typeB, typeAB)

x.test <- cbind(1, 0, 0, 0)
dimnames(x.test)[[2]] <- dimnames(x.train)[[2]]

## parallel::mcparallel/mccollect do not exist on windows
if(.Platform$OS.type=='unix') {
    ##test BART with token run to ensure installation works
    post <- mc.crisk2.bart(x.train=x.train, times=times, delta=delta,
                           seed=99, mc.cores=2, nskip=5, ndpost=5,
                           keepevery=1)

    pre <- surv.pre.bart(x.train=x.train, x.test=x.test,
                         times=times, delta=delta)

    K <- post$K

    pred <- mc.crisk2.pwbart(pre$tx.test, pre$tx.test,
                             post$treedraws, post$treedraws2,
                             post$binaryOffset, post$binaryOffset2)
}

## Not run: 

## run one long MCMC chain in one process
## set.seed(99)
## post <- crisk2.bart(x.train=x.train, times=times, delta=delta, x.test=x.test)

## in the interest of time, consider speeding it up by parallel processing
## run "mc.cores" number of shorter MCMC chains in parallel processes
post <- mc.crisk2.bart(x.train=x.train,
                       times=times, delta=delta,
                       x.test=x.test, seed=99, mc.cores=8)

check <- mc.crisk2.pwbart(post$tx.test, post$tx.test,
                          post$treedraws, post$treedraws2,
                          post$binaryOffset,
                          post$binaryOffset2, mc.cores=8)
## check <- predict(post, newdata=post$tx.test, newdata2=post$tx.test2,
##                  mc.cores=8)

print(c(post$surv.test.mean[1], check$surv.test.mean[1],
        post$surv.test.mean[1]-check$surv.test.mean[1]), digits=22)

print(all(round(post$surv.test.mean, digits=9)==
          round(check$surv.test.mean, digits=9)))

print(c(post$cif.test.mean[1], check$cif.test.mean[1],
        post$cif.test.mean[1]-check$cif.test.mean[1]), digits=22)

print(all(round(post$cif.test.mean, digits=9)==
          round(check$cif.test.mean, digits=9)))

print(c(post$cif.test2.mean[1], check$cif.test2.mean[1],
        post$cif.test2.mean[1]-check$cif.test2.mean[1]), digits=22)

print(all(round(post$cif.test2.mean, digits=9)==
          round(check$cif.test2.mean, digits=9)))

## End(Not run)