Treeffuser is a new method for probabilistic predictions: it provides predictive distributions rather than single point predictions. Treeffuser is fast, accurate, flexible, and easy to use!

A standard task in many fields is to predict a target variable y given some input x. Most algorithms provide point estimates: a single value of y for each x. In most real-world applications, however, it is important to have not only a point estimate but also a measure of the uncertainty associated with it. A natural way to quantify this uncertainty is with the full probability distribution p(y | x) of y given x. This is the task of probabilistic prediction.

With the distribution p(y | x) in hand, we can compute many quantities about y:

  1. its conditional mean and standard deviation;
  2. quantiles and prediction intervals;
  3. tail probabilities such as P(y > c | x) for a threshold c;

and many other quantities!

The main difficulty is that, in practice, we never observe p(y | x) itself; we only have a dataset of samples drawn from it. How does one estimate p(y | x) from such a dataset?

[Figure: probabilistic predictions p(y | x). Left: samples (x, y). Right: learned conditional distribution p(y | x) at a chosen value of x.]

Treeffuser

Treeffuser is a method for producing such probabilistic predictions. It estimates p(y|x) by combining two powerful techniques:

  1. Conditional diffusions: they enable Treeffuser to model essentially any distribution p(y | x) without requiring special assumptions from the user.
  2. Gradient-boosted trees: they make Treeffuser fast, easy to use, and well suited to tabular data, the most common type of data in real-world problems.

Together, these two techniques make Treeffuser a powerful tool for probabilistic prediction; the sketch below illustrates the core idea.
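
To give a flavor of how the two pieces fit together, here is a conceptual sketch of the training and sampling loops. This is not the library's actual implementation: we use scikit-learn's GradientBoostingRegressor as a stand-in for LightGBM, a simple variance-exploding noise schedule, and a one-dimensional y; all names and constants are illustrative. The trees learn to predict the noise injected into y by the diffusion, which yields an estimate of the score used to run the diffusion in reverse.

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(0)
SIGMA_MIN, SIGMA_MAX = 0.01, 2.0

def sigma(t):
    # Variance-exploding noise schedule: sigma grows from SIGMA_MIN to SIGMA_MAX.
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def fit_score_model(X, y, n_repeats=20):
    # Denoising score matching: corrupt each (x, y) pair at several random
    # noise levels t, and train trees to recover the injected noise eps
    # from the features (x, y_noised, t).
    Xr = np.repeat(X, n_repeats, axis=0)
    yr = np.repeat(y, n_repeats)
    t = rng.uniform(0.0, 1.0, size=yr.shape[0])
    eps = rng.standard_normal(yr.shape[0])
    y_noised = yr + sigma(t) * eps
    model = GradientBoostingRegressor(n_estimators=300, max_depth=4)
    model.fit(np.column_stack([Xr, y_noised, t]), eps)
    return model  # the score is approximated by -model.predict(.) / sigma(t)

def sample_one(model, x, n_steps=50):
    # Reverse-time Euler-Maruyama: start from pure noise at t=1 and integrate
    # the reverse SDE down to t~0, guided by the tree-based score estimate.
    y = sigma(1.0) * rng.standard_normal()
    ts = np.linspace(1.0, 1e-3, n_steps)
    dt = ts[0] - ts[1]
    log_ratio = np.log(SIGMA_MAX / SIGMA_MIN)
    for t in ts:
        g2 = 2.0 * sigma(t) ** 2 * log_ratio  # g(t)^2 = d sigma^2(t) / dt
        feats = np.concatenate([np.atleast_1d(x), [y, t]])[None, :]
        score = -model.predict(feats)[0] / sigma(t)
        y += g2 * score * dt + np.sqrt(g2 * dt) * rng.standard_normal()
    return y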

Flexibility

Thanks to conditional diffusions, Treeffuser can model heavy-tailed, multimodal, skewed, and other complex distributions.

Multivariate output

Treeffuser supports multivariate output y, allowing for the modeling of complex correlations between multiple output variables.
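
As a sketch of what this looks like, consider two outputs whose correlation depends on the input. The data-generating process below is made up for illustration, and we assume the sampled array has shape (n_samples, n_observations, y_dim); check the documentation for the exact output shape.

import numpy as np
from treeffuser import Treeffuser

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 3))
# Two outputs with an input-dependent relationship (illustrative process).
y1 = X[:, 0] + 0.1 * rng.standard_normal(2000)
y2 = y1 * X[:, 1] + 0.1 * rng.standard_normal(2000)
y = np.column_stack([y1, y2])

model = Treeffuser()
model.fit(X, y)

samples = model.sample(X[:5], n_samples=200)  # assumed shape: (200, 5, 2)
# Correlation between the two outputs at the first test point.
print(np.corrcoef(samples[:, 0, 0], samples[:, 0, 1])[0, 1])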

Minimal Tuning

Treeffuser works out of the box with its default hyperparameters; it does not require extensive tuning.

Easy to use

Treeffuser adopts the scikit-learn API. To use it, simply call model.fit(X, y) and model.sample(X).

Designed for tabular data

Treeffuser is built on top of LightGBM, a fast and efficient implementation of gradient-boosted trees with state-of-the-art performance on tabular data problems.

Missing or categorical data

Treeffuser natively implements principled methods for handling missing data at training and inference time, making it easy to use in practice and robust to real-world data.
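
For instance, rows of X with missing entries encoded as np.nan can be passed directly to fit and sample. A minimal sketch, assuming the default configuration:

import numpy as np
from treeffuser import Treeffuser

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
y = X[:, 0] + 0.1 * rng.standard_normal(1000)
X[rng.random(X.shape) < 0.2] = np.nan  # knock out ~20% of the entries

model = Treeffuser()
model.fit(X, y)                              # training tolerates the NaNs
samples = model.sample(X[:3], n_samples=50)  # and so does inference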

Competitive performance

Treeffuser achieves state-of-the-art performance on a variety of benchmarks and is competitive with the best existing methods for probabilistic prediction.
| dataset | N, dx, dy | Deep ensembles | NGBoost Gaussian | iBUG XGBoost | Quantile regression | DRF | Treeffuser | Treeffuser (no tuning) |
|---------|-----------|----------------|------------------|--------------|---------------------|-----|------------|------------------------|
| bike | 17379, 12, 1 | **1.61±0.05** | 7.15±0.18 | 2.21±0.04 | 1.85±0.07 | 2.15±0.06 | **1.60±0.05** | 1.64±0.05 |
| energy | 768, 8, 2 | 5.00±0.71 | 4.78±0.49 | NA | NA | 5.43±0.69 | **3.07±0.40** | **3.32±0.48** |
| kin8nm | 8192, 8, 1 | **3.59±0.12** | 9.48±0.29 | 7.14±0.18 | 6.63±0.16 | 9.44±0.20 | 5.89±0.14 | **5.88±0.17** |
| movies | 7415, 9, 1 | 2.94±0.35 | X | 3.18±0.29 | 7.90±0.68 | 5.57±0.61 | **2.68±0.31** | **2.69±0.28** |
| naval | 11934, 17, 1 | 4.11±0.39 | 4.43±0.19 | 4.70±0.16 | 16.86±2.46 | 4.55±0.16 | **2.02±0.08** | **2.46±0.07** |
| news | 39644, 58, 1 | 2.53±0.27 | X | 3.75±0.43 | 2.32±0.17 | **1.98±0.17** | **1.98±0.17** | **1.98±0.16** |
| power | 9568, 4, 1 | 2.06±0.10 | 2.01±0.13 | 1.71±0.09 | 5.40±0.12 | 1.90±0.11 | **1.49±0.07** | **1.52±0.07** |
| super. | 21263, 81, 1 | 4.89±0.31 | 5.24±0.50 | 4.43±0.24 | 3.79±0.14 | 4.32±0.52 | **3.52±0.13** | **3.60±0.15** |
| wine | 6497, 12, 1 | 3.59±0.10 | 3.82±0.11 | 3.47±0.13 | 3.24±0.14 | 3.30±0.12 | **2.59±0.13** | **2.67±0.13** |
| yacht | 308, 6, 1 | 4.86±1.38 | 3.67±1.47 | **3.17±1.13** | 3.53±1.30 | 7.73±2.43 | **3.11±0.99** | 3.39±0.97 |

Continuous Ranked Probability Score (CRPS) by dataset and method (lower is better). X indicates that the method failed to run; NA that the method is not directly applicable to multivariate outputs. Standard deviations are estimated via 10-fold cross-validation, and the two best methods per dataset are in bold. Treeffuser provides the most accurate probabilistic predictions, even with default hyperparameters (no tuning).
| dataset | N, dx, dy | Deep ensembles | NGBoost Gaussian | iBUG XGBoost | Quantile regression | DRF | Treeffuser | Treeffuser (no tuning) |
|---------|-----------|----------------|------------------|--------------|---------------------|-----|------------|------------------------|
| bike | 17379, 12, 1 | **3.70±0.13** | 11.52±0.30 | 4.16±0.10 | 4.63±0.21 | 4.86±0.18 | **3.69±0.14** | 3.81±0.15 |
| energy | 768, 8, 2 | 11.34±0.28 | 11.41±0.25 | NA | NA | 13.73±1.69 | **8.32±1.35** | **8.79±1.46** |
| kin8nm | 8192, 8, 1 | **0.64±0.02** | 1.71±0.06 | 1.94±0.06 | 1.23±0.14 | 1.18±0.04 | **1.06±0.03** | **1.06±0.02** |
| movies | 7415, 9, 1 | 5.75±1.38 | X | 4.87±0.80 | 49.93±0.07 | **1.31±0.47** | **2.71±1.36** | 5.17±1.32 |
| naval | 11934, 17, 1 | 13.31±1.90 | **5.50±0.75** | **4.10±0.58** | 15.05±0.58 | 17.37±0.40 | 7.75±0.41 | 9.10±0.39 |
| news | 39644, 58, 1 | 1.95±2.66 | X | 1.21±0.37 | 1.12±0.40 | **1.08±0.41** | 1.10±0.40 | **1.09±0.41** |
| power | 9568, 4, 1 | 3.11±0.32 | 3.66±0.35 | 3.16±0.53 | 9.36±0.58 | 3.65±0.35 | **2.93±0.30** | **3.02±0.30** |
| super. | 21263, 81, 1 | 11.28±0.74 | 11.25±0.87 | 9.53±0.32 | 9.06±0.58 | **3.87±0.56** | **3.50±0.70** | 4.87±0.28 |
| wine | 6497, 12, 1 | **1.56±0.17** | 6.85±0.18 | 6.83±0.28 | 6.30±0.28 | 8.24±0.21 | **5.88±0.17** | 5.98±0.15 |
| yacht | 308, 6, 1 | 1.06±0.46 | **0.85±0.13** | **0.67±0.26** | 1.28±0.79 | 6.34±0.01 | 2.01±0.08 | 2.08±0.09 |

Root Mean Squared Error (RMSE) by dataset and method (lower is better). X indicates that the method failed to run; NA that the method is not directly applicable to multivariate outputs. Standard deviations are estimated via 10-fold cross-validation, and the two best methods per dataset are in bold. Treeffuser provides accurate point predictions, even with default hyperparameters (no tuning).
| dataset | N, dx, dy | Deep ensembles | NGBoost Gaussian | iBUG XGBoost | Quantile regression | DRF | Treeffuser | Treeffuser (no tuning) |
|---------|-----------|----------------|------------------|--------------|---------------------|-----|------------|------------------------|
| bike | 17379, 12, 1 | 2.38±0.97 | 17.79±0.68 | 7.68±0.89 | 2.95±0.29 | 3.45±0.43 | **1.86±0.87** | **1.73±0.62** |
| energy | 768, 8, 2 | 8.08±1.41 | 5.82±1.74 | NA | NA | **4.54±0.74** | 5.33±1.42 | **3.11±1.14** |
| kin8nm | 8192, 8, 1 | 2.51±0.50 | 2.75±0.11 | 3.83±1.56 | 4.74±0.93 | 3.87±0.56 | **1.77±0.74** | **1.36±1.31** |
| movies | 7415, 9, 1 | 19.26±1.13 | X | 4.73±1.65 | 4.57±0.94 | 2.87±0.52 | **1.33±0.16** | **1.34±0.31** |
| naval | 11934, 17, 1 | 5.29±1.68 | 2.53±0.00 | 3.78±1.56 | 4.75±0.43 | 3.87±0.56 | **1.73±0.14** | **1.44±1.30** |
| news | 39644, 58, 1 | 21.53±0.63 | X | 13.57±1.40 | 18.96±0.42 | **1.14±0.25** | 4.87±0.28 | **3.50±0.70** |
| power | 9568, 4, 1 | 2.52±1.18 | 2.53±1.00 | 3.83±1.65 | 4.75±0.93 | 2.87±0.56 | **1.77±0.74** | **1.36±1.31** |
| super. | 21263, 81, 1 | **3.19±0.13** | **2.68±0.15** | 3.45±0.77 | 4.86±0.90 | 22.02±1.13 | 5.22±0.69 | 5.70±0.55 |
| wine | 6497, 12, 1 | 13.25±0.57 | 9.28±0.25 | 8.45±0.20 | 6.28±0.29 | 8.34±0.21 | **2.08±0.09** | **2.01±0.08** |
| yacht | 308, 6, 1 | **3.19±0.56** | **2.68±0.15** | 3.45±0.77 | 4.86±0.90 | 22.02±1.13 | 5.22±0.69 | 5.70±0.55 |

Mean Absolute Calibration Error (MACE) by dataset and method (lower is better). X indicates that the method failed to run; NA that the method is not directly applicable to multivariate outputs. Standard deviations are estimated via 10-fold cross-validation, and the two best methods per dataset are in bold. Treeffuser has competitive calibration error across most datasets.

Usage Examples

Bimodal, heteroskedastic response

A simple response with two sinusoidal components [colab] [source]

We demonstrate how to train Treeffuser and generate samples from it on synthetic data with a bimodal, heteroskedastic response built from two sinusoidal components.
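
A minimal sketch of this experiment (the data-generating process here is illustrative and may differ from the notebook; we assume sample returns an array of shape (n_samples, n_observations)):

import numpy as np
import matplotlib.pyplot as plt
from treeffuser import Treeffuser

rng = np.random.default_rng(0)
n = 5000
x = rng.uniform(0, 2 * np.pi, size=n)
branch = rng.integers(0, 2, size=n)                         # pick one of two modes
y = np.where(branch == 0, np.sin(x), np.sin(2 * x) + 2.0)   # two sinusoidal branches
y += 0.1 * (1 + x / (2 * np.pi)) * rng.standard_normal(n)   # noise grows with x

model = Treeffuser()
model.fit(x[:, None], y)

# Sample the learned p(y | x) on a grid and overlay it on the training data.
x_grid = np.linspace(0, 2 * np.pi, 200)
samples = model.sample(x_grid[:, None], n_samples=50)
plt.scatter(x, y, s=1, alpha=0.2, label="data")
plt.scatter(np.tile(x_grid, samples.shape[0]), samples.ravel(), s=1, label="samples")
plt.legend()
plt.show()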

Optimal inventory allocation

Optimal Inventory Allocation [colab] [source]

We use Treeffuser to model the distribution of demand for products in a retail store. This allows us to compute the optimal inventory allocation that minimizes the risk of stockouts and waste while maximizing profits.
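
The core computation is a classic newsvendor rule: given a per-unit cost of understocking (lost margin) and of overstocking (waste), the expected-profit-maximizing stock level is the critical-ratio quantile of the demand distribution, which can be read directly off Treeffuser's samples. The costs and demand distribution below are illustrative stand-ins for the notebook's data.

import numpy as np

c_under, c_over = 4.0, 1.0                     # illustrative per-unit costs
critical_ratio = c_under / (c_under + c_over)  # = 0.8

# Stand-in for draws from the learned demand distribution, e.g.
#   demand_samples = model.sample(x_store, n_samples=1000)
demand_samples = np.random.default_rng(0).gamma(5.0, 20.0, size=1000)

# The optimal order quantity is the critical-ratio quantile of demand.
optimal_stock = np.quantile(demand_samples, critical_ratio)
print(f"order {optimal_stock:.0f} units")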

Next steps

Installing and using Treeffuser is simple. Our package is available on PyPI and can be installed with pip:

pip install treeffuser

We adhere to the scikit-learn API, so you can use Treeffuser as you would any other scikit-learn model.

from treeffuser import Treeffuser

# Load some data
X_train, X_test, y_train, y_test = ...

# Fit the model as you would any scikit-learn estimator
model = Treeffuser()
model.fit(X_train, y_train)

# Draw samples from the learned conditional distribution p(y | x)
y_samples = model.sample(X_test, n_samples=100)  # y_samples.shape[0] is 100

# Estimate downstream quantities of interest
y_mean = y_samples.mean(axis=0)  # conditional mean for each x
y_std = y_samples.std(axis=0)    # conditional std for each x

We hope you enjoy using Treeffuser! Please check out the paper, the documentation, and the code on GitHub for more information.
