Treeffuser is a new method for probabilistic predictions: it provides predictive distributions rather than single point predictions. Treeffuser is fast, accurate, flexible, and easy to use!
A standard task in many fields is to predict a target variable `y` from features `x`. With the predictive distribution `p(y | x)` one can compute point predictions, quantiles, prediction intervals, and many other quantities!

The main difficulty is that in practice we don't have access to `p(y | x)`. Treeffuser is a method for producing such probabilistic predictions. It estimates `p(y | x)` with a conditional diffusion model whose score function is learned with gradient-boosted trees. These two techniques combined make Treeffuser a powerful tool for probabilistic predictions.
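Once samples from the predictive distribution are available, downstream quantities follow directly. A minimal NumPy sketch, with Gaussian noise standing in for draws from `p(y | x)` at a fixed `x` (purely illustrative, not Treeffuser output):

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical draws from p(y | x) for a single input x;
# in practice these would come from a probabilistic model.
samples = rng.normal(loc=2.0, scale=0.5, size=10_000)

mean = samples.mean()                        # conditional mean E[y | x]
q10, q90 = np.quantile(samples, [0.1, 0.9])  # conditional quantiles
p_exceed = (samples > 3.0).mean()            # tail probability P(y > 3 | x)
```

Any functional of the distribution can be estimated this way, which is why sampling access to `p(y | x)` is more useful than a single point prediction.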
With conditional diffusions, Treeffuser can model heavy-tailed, multimodal, skewed, and other complex distributions.
Treeffuser supports multivariate outputs.
Treeffuser works out of the box with its default hyperparameters; it does not require extensive tuning.
Treeffuser adopts the scikit-learn API. To use it, simply call `model.fit(X, y)` and `model.sample(X)`.
Treeffuser is built on top of LightGBM, a fast and efficient implementation of gradient-boosted trees that has state-of-the-art performance on tabular data problems.
Treeffuser natively implements principled methods for handling missing data at both training and inference time, making it easy to use in practice and robust to real-world data.
dataset | N, dx, dy | Deep ensembles | NGBoost Gaussian | iBUG XGBoost | Quantile regression | DRF | Treeffuser | Treeffuser no tuning |
---|---|---|---|---|---|---|---|---|
bike | 17379, 12, 1 | 1.61±0.05 | 7.15±0.18 | 2.21±0.04 | 1.85±0.07 | 2.15±0.06 | 1.60±0.05 | 1.64±0.05 |
energy | 768, 8, 2 | 5.00±0.71 | 4.78±0.49 | NA | NA | 5.43±0.69 | 3.07±0.40 | 3.32±0.48 |
kin8nm | 8192, 8, 1 | 3.59±0.12 | 9.48±0.29 | 7.14±0.18 | 6.63±0.16 | 9.44±0.20 | 5.89±0.14 | 5.88±0.17 |
movies | 7415, 9, 1 | 2.94±0.35 | X | 3.18±0.29 | 7.90±0.68 | 5.57±0.61 | 2.68±0.31 | 2.69±0.28 |
naval | 11934, 17, 1 | 4.11±0.39 | 4.43±0.19 | 4.70±0.16 | 16.86±2.46 | 4.55±0.16 | 2.02±0.08 | 2.46±0.07 |
news | 39644, 58, 1 | 2.53±0.27 | X | 3.75±0.43 | 2.32±0.17 | 1.98±0.17 | 1.98±0.17 | 1.98±0.16 |
power | 9568, 4, 1 | 2.06±0.10 | 2.01±0.13 | 1.71±0.09 | 5.40±0.12 | 1.90±0.11 | 1.49±0.07 | 1.52±0.07 |
super. | 21263, 81, 1 | 4.89±0.31 | 5.24±0.50 | 4.43±0.24 | 3.79±0.14 | 4.32±0.52 | 3.52±0.13 | 3.60±0.15 |
wine | 6497, 12, 1 | 3.59±0.10 | 3.82±0.11 | 3.47±0.13 | 3.24±0.14 | 3.30±0.12 | 2.59±0.13 | 2.67±0.13 |
yacht | 308, 6, 1 | 4.86±1.38 | 3.67±1.47 | 3.17±1.13 | 3.53±1.30 | 7.73±2.43 | 3.11±0.99 | 3.39±0.97 |
dataset | N, dx, dy | Deep ensembles | NGBoost Gaussian | iBUG XGBoost | Quantile regression | DRF | Treeffuser | Treeffuser no tuning |
---|---|---|---|---|---|---|---|---|
bike | 17379, 12, 1 | 3.70±0.13 | 11.52±0.30 | 4.16±0.10 | 4.63±0.21 | 4.86±0.18 | 3.69±0.14 | 3.81±0.15 |
energy | 768, 8, 2 | 11.34±0.28 | 11.41±0.25 | NA | NA | 13.73±1.69 | 8.32±1.35 | 8.79±1.46 |
kin8nm | 8192, 8, 1 | 0.64±0.02 | 1.71±0.06 | 1.94±0.06 | 1.23±0.14 | 1.18±0.04 | 1.06±0.03 | 1.06±0.02 |
movies | 7415, 9, 1 | 5.75±1.38 | X | 4.87±0.80 | 49.93±0.07 | 1.31±0.47 | 2.71±1.36 | 5.17±1.32 |
naval | 11934, 17, 1 | 13.31±1.90 | 5.50±0.75 | 4.10±0.58 | 15.05±0.58 | 17.37±0.40 | 7.75±0.41 | 9.10±0.39 |
news | 39644, 58, 1 | 1.95±2.66 | X | 1.21±0.37 | 1.12±0.40 | 1.08±0.41 | 1.10±0.40 | 1.09±0.41 |
power | 9568, 4, 1 | 3.11±0.32 | 3.66±0.35 | 3.16±0.53 | 9.36±0.58 | 3.65±0.35 | 2.93±0.30 | 3.02±0.30 |
super. | 21263, 81, 1 | 11.28±0.74 | 11.25±0.87 | 9.53±0.32 | 9.06±0.58 | 3.87±0.56 | 3.50±0.70 | 4.87±0.28 |
wine | 6497, 12, 1 | 1.56±0.17 | 6.85±0.18 | 6.83±0.28 | 6.30±0.28 | 8.24±0.21 | 5.88±0.17 | 5.98±0.15 |
yacht | 308, 6, 1 | 1.06±0.46 | 0.85±0.13 | 0.67±0.26 | 1.28±0.79 | 6.34±0.01 | 2.01±0.08 | 2.08±0.09 |
dataset | N, dx, dy | Deep ensembles | NGBoost Gaussian | iBUG XGBoost | Quantile regression | DRF | Treeffuser | Treeffuser no tuning |
---|---|---|---|---|---|---|---|---|
bike | 17379, 12, 1 | 2.38±0.97 | 17.79±0.68 | 7.68±0.89 | 2.95±0.29 | 3.45±0.43 | 1.86±0.87 | 1.73±0.62 |
energy | 768, 8, 2 | 8.08±1.41 | 5.82±1.74 | NA | NA | 4.54±0.74 | 5.33±1.42 | 3.11±1.14 |
kin8nm | 8192, 8, 1 | 2.51±0.50 | 2.75±0.11 | 3.83±1.56 | 4.74±0.93 | 3.87±0.56 | 1.77±0.74 | 1.36±1.31 |
movies | 7415, 9, 1 | 19.26±1.13 | X | 4.73±1.65 | 4.57±0.94 | 2.87±0.52 | 1.33±0.16 | 1.34±0.31 |
naval | 11934, 17, 1 | 5.29±1.68 | 2.53±0.00 | 3.78±1.56 | 4.75±0.43 | 3.87±0.56 | 1.73±0.14 | 1.44±1.30 |
news | 39644, 58, 1 | 21.53±0.63 | X | 13.57±1.40 | 18.96±0.42 | 1.14±0.25 | 4.87±0.28 | 3.50±0.70 |
power | 9568, 4, 1 | 2.52±1.18 | 2.53±1.00 | 3.83±1.65 | 4.75±0.93 | 2.87±0.56 | 1.77±0.74 | 1.36±1.31 |
super. | 21263, 81, 1 | 3.19±0.13 | 2.68±0.15 | 3.45±0.77 | 4.86±0.90 | 22.02±1.13 | 5.22±0.69 | 5.70±0.55 |
wine | 6497, 12, 1 | 13.25±0.57 | 9.28±0.25 | 8.45±0.20 | 6.28±0.29 | 8.34±0.21 | 2.08±0.09 | 2.01±0.08 |
yacht | 308, 6, 1 | 3.19±0.56 | 2.68±0.15 | 3.45±0.77 | 4.86±0.90 | 22.02±1.13 | 5.22±0.69 | 5.70±0.55 |
Installing Treeffuser is simple: the package is available on PyPI and can be installed with pip.
```bash
pip install treeffuser
```
We adhere to the scikit-learn API, so you can use Treeffuser as you would any other scikit-learn model.
```python
from treeffuser import Treeffuser

# Load some data
X_train, X_test, y_train, y_test = ...

model = Treeffuser()
model.fit(X_train, y_train)
y_samples = model.sample(X_test, n_samples=100)  # y_samples.shape[0] is 100

# Estimate downstream quantities of interest
y_mean = y_samples.mean(axis=0)  # conditional mean for each test point
y_std = y_samples.std(axis=0)    # conditional std for each test point
```
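Beyond means and standard deviations, the same sample array supports interval estimates. A small sketch with a synthetic array standing in for the model's sampled output (shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
# Stand-in for sampled predictions of shape (n_samples, n_test),
# here with 100 draws for 5 hypothetical test points.
y_samples = rng.normal(size=(100, 5))

# Empirical 95% prediction interval for each test point.
lo, hi = np.quantile(y_samples, [0.025, 0.975], axis=0)
```

Because the intervals come from samples rather than a parametric form, they remain valid for skewed or multimodal predictive distributions.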
We hope you enjoy using Treeffuser! Please check out the paper, the documentation, and the code on GitHub for more information.