{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Forecasting Walmart sales with Treeffuser\n", "\n", "In this tutorial we show how to use Treeffuser to model and forecast Walmart sales using the M5 forecasting dataset from Kaggle." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Getting started\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get started, we first install `treeffuser` and import the relevant libraries." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!pip install treeffuser\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from pathlib import Path\n", "\n", "from tqdm import tqdm\n", "from treeffuser import Treeffuser\n", "\n", "# load autoreload extension\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, create a Kaggle account and download the data from https://www.kaggle.com/competitions/m5-forecasting-accuracy/data.\n", "\n", "If you're running this notebook in Colab, manually upload the necessary files (`calendar.csv`, `sales_train_validation.csv`, `sell_prices.csv`) by clicking the `Files` tab in the left sidebar and selecting `Upload`, then move them into a new folder named `m5`. Once the files are in place, the notebook can read and process the data.\n", "\n", "If you're running this notebook on your local machine, you can also use Kaggle's [command-line tool](https://www.kaggle.com/docs/api) and run the following from the command line:\n", "\n", "```bash\n", "cd ./m5 # path to the folder where you want to save the data\n", "kaggle competitions download -c m5-forecasting-accuracy\n", "```\n", "\n", "Then unzip the archive with your favorite tool. On Linux/macOS:\n", "\n", "```bash\n", "unzip m5-forecasting-accuracy.zip\n", "```\n", "\n", "We'll be using the following files: `calendar.csv`, `sales_train_validation.csv`, and `sell_prices.csv`." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Load the data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data_path = Path(\"./m5\")  # change to the folder where you extracted the data archive\n", "\n", "calendar_df = pd.read_csv(data_path / \"calendar.csv\")\n", "sales_train_df = pd.read_csv(data_path / \"sales_train_validation.csv\")\n", "sell_prices_df = pd.read_csv(data_path / \"sell_prices.csv\")\n", "\n", "# add explicit columns for the day, month, and year for ease of processing\n", "calendar_df[\"date\"] = pd.to_datetime(calendar_df[\"date\"])\n", "calendar_df[\"day\"] = calendar_df[\"date\"].dt.day\n", "calendar_df[\"month\"] = calendar_df[\"date\"].dt.month\n", "calendar_df[\"year\"] = calendar_df[\"date\"].dt.year" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The data\n", "\n", "### Preprocessing\n", "`sell_prices_df` contains the price of each item in each store for each week; the `wm_yr_wk` column identifies the week." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table border=\"1\" class=\"dataframe\"><thead><tr><th></th><th>store_id</th><th>item_id</th><th>wm_yr_wk</th><th>sell_price</th></tr></thead><tbody><tr><th>0</th><td>CA_1</td><td>HOBBIES_1_001</td><td>11325</td><td>9.58</td></tr><tr><th>1</th><td>CA_1</td><td>HOBBIES_1_001</td><td>11326</td><td>9.58</td></tr><tr><th>2</th><td>CA_1</td><td>HOBBIES_1_001</td><td>11327</td><td>8.26</td></tr><tr><th>3</th><td>CA_1</td><td>HOBBIES_1_001</td><td>11328</td><td>8.26</td></tr><tr><th>4</th><td>CA_1</td><td>HOBBIES_1_001</td><td>11329</td><td>8.26</td></tr></tbody></table></div>
" ], "text/plain": [ " store_id item_id wm_yr_wk sell_price\n", "0 CA_1 HOBBIES_1_001 11325 9.58\n", "1 CA_1 HOBBIES_1_001 11326 9.58\n", "2 CA_1 HOBBIES_1_001 11327 8.26\n", "3 CA_1 HOBBIES_1_001 11328 8.26\n", "4 CA_1 HOBBIES_1_001 11329 8.26" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sell_prices_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`calendar_df` contains information about the dates on which the products were sold." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table border=\"1\" class=\"dataframe\"><thead><tr><th></th><th>date</th><th>wm_yr_wk</th><th>weekday</th><th>wday</th><th>month</th><th>year</th><th>d</th><th>event_name_1</th><th>event_type_1</th><th>event_name_2</th><th>event_type_2</th><th>snap_CA</th><th>snap_TX</th><th>snap_WI</th><th>day</th></tr></thead><tbody><tr><th>0</th><td>2011-01-29</td><td>11101</td><td>Saturday</td><td>1</td><td>1</td><td>2011</td><td>d_1</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>0</td><td>0</td><td>0</td><td>29</td></tr><tr><th>1</th><td>2011-01-30</td><td>11101</td><td>Sunday</td><td>2</td><td>1</td><td>2011</td><td>d_2</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>0</td><td>0</td><td>0</td><td>30</td></tr><tr><th>2</th><td>2011-01-31</td><td>11101</td><td>Monday</td><td>3</td><td>1</td><td>2011</td><td>d_3</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>0</td><td>0</td><td>0</td><td>31</td></tr><tr><th>3</th><td>2011-02-01</td><td>11101</td><td>Tuesday</td><td>4</td><td>2</td><td>2011</td><td>d_4</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>1</td><td>1</td><td>0</td><td>1</td></tr><tr><th>4</th><td>2011-02-02</td><td>11101</td><td>Wednesday</td><td>5</td><td>2</td><td>2011</td><td>d_5</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>1</td><td>0</td><td>1</td><td>2</td></tr></tbody></table></div>
" ], "text/plain": [ " date wm_yr_wk weekday wday month year d event_name_1 \\\n", "0 2011-01-29 11101 Saturday 1 1 2011 d_1 NaN \n", "1 2011-01-30 11101 Sunday 2 1 2011 d_2 NaN \n", "2 2011-01-31 11101 Monday 3 1 2011 d_3 NaN \n", "3 2011-02-01 11101 Tuesday 4 2 2011 d_4 NaN \n", "4 2011-02-02 11101 Wednesday 5 2 2011 d_5 NaN \n", "\n", " event_type_1 event_name_2 event_type_2 snap_CA snap_TX snap_WI day \n", "0 NaN NaN NaN 0 0 0 29 \n", "1 NaN NaN NaN 0 0 0 30 \n", "2 NaN NaN NaN 0 0 0 31 \n", "3 NaN NaN NaN 1 1 0 1 \n", "4 NaN NaN NaN 1 0 1 2 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "calendar_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`sales_train_df` contains the number of units of each item sold in each store, with one column per day: for example, the `d_1907` column holds the number of units sold on the 1907th day." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table border=\"1\" class=\"dataframe\"><thead><tr><th></th><th>id</th><th>item_id</th><th>dept_id</th><th>cat_id</th><th>store_id</th><th>state_id</th><th>d_1</th><th>d_2</th><th>d_3</th><th>d_4</th><th>...</th><th>d_1904</th><th>d_1905</th><th>d_1906</th><th>d_1907</th><th>d_1908</th><th>d_1909</th><th>d_1910</th><th>d_1911</th><th>d_1912</th><th>d_1913</th></tr></thead><tbody><tr><th>0</th><td>HOBBIES_1_001_CA_1_validation</td><td>HOBBIES_1_001</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>0</td><td>0</td><td>0</td><td>0</td><td>...</td><td>1</td><td>3</td><td>0</td><td>1</td><td>1</td><td>1</td><td>3</td><td>0</td><td>1</td><td>1</td></tr><tr><th>1</th><td>HOBBIES_1_002_CA_1_validation</td><td>HOBBIES_1_002</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>0</td><td>0</td><td>0</td><td>0</td><td>...</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0</td></tr><tr><th>2</th><td>HOBBIES_1_003_CA_1_validation</td><td>HOBBIES_1_003</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>0</td><td>0</td><td>0</td><td>0</td><td>...</td><td>2</td><td>1</td><td>2</td><td>1</td><td>1</td><td>1</td><td>0</td><td>1</td><td>1</td><td>1</td></tr><tr><th>3</th><td>HOBBIES_1_004_CA_1_validation</td><td>HOBBIES_1_004</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>0</td><td>0</td><td>0</td><td>0</td><td>...</td><td>1</td><td>0</td><td>5</td><td>4</td><td>1</td><td>0</td><td>1</td><td>3</td><td>7</td><td>2</td></tr><tr><th>4</th><td>HOBBIES_1_005_CA_1_validation</td><td>HOBBIES_1_005</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>0</td><td>0</td><td>0</td><td>0</td><td>...</td><td>2</td><td>1</td><td>1</td><td>0</td><td>1</td><td>1</td><td>2</td><td>2</td><td>2</td><td>4</td></tr></tbody></table><p>5 rows × 1919 columns</p></div>
" ], "text/plain": [ " id item_id dept_id cat_id store_id \\\n", "0 HOBBIES_1_001_CA_1_validation HOBBIES_1_001 HOBBIES_1 HOBBIES CA_1 \n", "1 HOBBIES_1_002_CA_1_validation HOBBIES_1_002 HOBBIES_1 HOBBIES CA_1 \n", "2 HOBBIES_1_003_CA_1_validation HOBBIES_1_003 HOBBIES_1 HOBBIES CA_1 \n", "3 HOBBIES_1_004_CA_1_validation HOBBIES_1_004 HOBBIES_1 HOBBIES CA_1 \n", "4 HOBBIES_1_005_CA_1_validation HOBBIES_1_005 HOBBIES_1 HOBBIES CA_1 \n", "\n", " state_id d_1 d_2 d_3 d_4 ... d_1904 d_1905 d_1906 d_1907 d_1908 \\\n", "0 CA 0 0 0 0 ... 1 3 0 1 1 \n", "1 CA 0 0 0 0 ... 0 0 0 0 0 \n", "2 CA 0 0 0 0 ... 2 1 2 1 1 \n", "3 CA 0 0 0 0 ... 1 0 5 4 1 \n", "4 CA 0 0 0 0 ... 2 1 1 0 1 \n", "\n", " d_1909 d_1910 d_1911 d_1912 d_1913 \n", "0 1 3 0 1 1 \n", "1 1 0 0 0 0 \n", "2 1 0 1 1 1 \n", "3 0 1 3 7 2 \n", "4 1 2 2 2 4 \n", "\n", "[5 rows x 1919 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sales_train_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To keep the tutorial fast, we only use the first 100 rows (item-store series) of `sales_train_df`. To align the sales data with the other DataFrames, we convert it to long format: the daily sales columns `d_{i}` are collapsed into a single `sales` column, and a `d` column records the day of each sales entry." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table border=\"1\" class=\"dataframe\"><thead><tr><th></th><th>item_id</th><th>dept_id</th><th>cat_id</th><th>store_id</th><th>state_id</th><th>d</th><th>sales</th></tr></thead><tbody><tr><th>0</th><td>HOBBIES_1_001</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_1</td><td>0</td></tr><tr><th>1</th><td>HOBBIES_1_001</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_2</td><td>0</td></tr><tr><th>2</th><td>HOBBIES_1_001</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_3</td><td>0</td></tr><tr><th>3</th><td>HOBBIES_1_001</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_4</td><td>0</td></tr><tr><th>4</th><td>HOBBIES_1_001</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_5</td><td>0</td></tr></tbody></table></div>
" ], "text/plain": [ " item_id dept_id cat_id store_id state_id d sales\n", "0 HOBBIES_1_001 HOBBIES_1 HOBBIES CA_1 CA d_1 0\n", "1 HOBBIES_1_001 HOBBIES_1 HOBBIES CA_1 CA d_2 0\n", "2 HOBBIES_1_001 HOBBIES_1 HOBBIES CA_1 CA d_3 0\n", "3 HOBBIES_1_001 HOBBIES_1 HOBBIES CA_1 CA d_4 0\n", "4 HOBBIES_1_001 HOBBIES_1 HOBBIES CA_1 CA d_5 0" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def convert_sales_data_from_wide_to_long(sales_df_wide):\n", "    index_vars = [\"item_id\", \"dept_id\", \"cat_id\", \"store_id\", \"state_id\"]\n", "    sales_df_long = pd.wide_to_long(\n", "        sales_df_wide.drop(columns=\"id\"),  # `id` is redundant with item_id and store_id\n", "        i=index_vars,\n", "        j=\"day\",\n", "        stubnames=[\"d\"],\n", "        sep=\"_\",\n", "    ).reset_index()\n", "\n", "    # the stub column `d` holds the sales and the suffix column `day` holds the day index\n", "    sales_df_long = sales_df_long.rename(columns={\"d\": \"sales\", \"day\": \"d\"})\n", "\n", "    # restore the \"d_{i}\" format for the day so that it matches calendar_df[\"d\"]\n", "    sales_df_long[\"d\"] = \"d_\" + sales_df_long[\"d\"].astype(str)\n", "    return sales_df_long\n", "\n", "\n", "# keep only the first 100 item-store series to keep the tutorial fast\n", "sales_train_df_long = convert_sales_data_from_wide_to_long(sales_train_df.iloc[:100])\n", "sales_train_df_long.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Number of sales over the entire timespan')" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.hist(\n", "    sales_train_df_long[\"sales\"],\n", "    bins=np.arange(0, 10 + 1.5) - 0.5,  # unit-width bins centered on 0, 1, ..., 10\n", "    density=True,\n", ")\n", "plt.xticks(range(11))\n", "plt.xlabel(\"Units sold per item and day\")\n", "plt.ylabel(\"Relative frequency\")\n", "plt.title(\"Number of sales over the entire timespan\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train and test sets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The subset we loaded comprises sales data of 100 items over 1,913 days. For simplicity, we keep only the first 365 days and discard the rest." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "n_items = 100\n", "n_days = 1913\n" ] } ], "source": [ "print(f\"n_items = {len(sales_train_df_long['item_id'].unique())}\")\n", "print(f\"n_days = {len(sales_train_df_long['d'].unique())}\")\n", "\n", "# extract the integer day index from the \"d_{i}\" labels\n", "sales_train_df_long[\"day_number\"] = (\n", "    sales_train_df_long[\"d\"].str.extract(r\"(\\d+)\", expand=False).astype(int)\n", ")\n", "data = sales_train_df_long[sales_train_df_long[\"day_number\"] <= 365].copy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For each item, we add the sales on each of the previous 30 days as lag features (`sales_lag_1`, ..., `sales_lag_30`) and merge the sales, calendar, and price data together. The merges are inner joins, so days on which an item has no listed price are dropped; this is why the output below starts at day 141 of `HOBBIES_1_002` rather than at day 1 of `HOBBIES_1_001`." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table border=\"1\" class=\"dataframe\"><thead><tr><th></th><th>item_id</th><th>dept_id</th><th>cat_id</th><th>store_id</th><th>state_id</th><th>d</th><th>sales</th><th>day_number</th><th>sales_lag_1</th><th>sales_lag_2</th><th>...</th><th>year</th><th>event_name_1</th><th>event_type_1</th><th>event_name_2</th><th>event_type_2</th><th>snap_CA</th><th>snap_TX</th><th>snap_WI</th><th>day</th><th>sell_price</th></tr></thead><tbody><tr><th>0</th><td>HOBBIES_1_002</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_141</td><td>0</td><td>141</td><td>0.0</td><td>0.0</td><td>...</td><td>2011</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>0</td><td>0</td><td>0</td><td>18</td><td>3.97</td></tr><tr><th>1</th><td>HOBBIES_1_002</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_142</td><td>0</td><td>142</td><td>0.0</td><td>0.0</td><td>...</td><td>2011</td><td>Father's day</td><td>Cultural</td><td>NaN</td><td>NaN</td><td>0</td><td>0</td><td>0</td><td>19</td><td>3.97</td></tr><tr><th>2</th><td>HOBBIES_1_002</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_143</td><td>0</td><td>143</td><td>0.0</td><td>0.0</td><td>...</td><td>2011</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>0</td><td>0</td><td>0</td><td>20</td><td>3.97</td></tr><tr><th>3</th><td>HOBBIES_1_002</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_144</td><td>1</td><td>144</td><td>0.0</td><td>0.0</td><td>...</td><td>2011</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>0</td><td>0</td><td>0</td><td>21</td><td>3.97</td></tr><tr><th>4</th><td>HOBBIES_1_002</td><td>HOBBIES_1</td><td>HOBBIES</td><td>CA_1</td><td>CA</td><td>d_145</td><td>0</td><td>145</td><td>1.0</td><td>0.0</td><td>...</td><td>2011</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>0</td><td>0</td><td>0</td><td>22</td><td>3.97</td></tr></tbody></table><p>5 rows × 53 columns</p></div>
" ], "text/plain": [ " item_id dept_id cat_id store_id state_id d sales \\\n", "0 HOBBIES_1_002 HOBBIES_1 HOBBIES CA_1 CA d_141 0 \n", "1 HOBBIES_1_002 HOBBIES_1 HOBBIES CA_1 CA d_142 0 \n", "2 HOBBIES_1_002 HOBBIES_1 HOBBIES CA_1 CA d_143 0 \n", "3 HOBBIES_1_002 HOBBIES_1 HOBBIES CA_1 CA d_144 1 \n", "4 HOBBIES_1_002 HOBBIES_1 HOBBIES CA_1 CA d_145 0 \n", "\n", " day_number sales_lag_1 sales_lag_2 ... year event_name_1 \\\n", "0 141 0.0 0.0 ... 2011 NaN \n", "1 142 0.0 0.0 ... 2011 Father's day \n", "2 143 0.0 0.0 ... 2011 NaN \n", "3 144 0.0 0.0 ... 2011 NaN \n", "4 145 1.0 0.0 ... 2011 NaN \n", "\n", " event_type_1 event_name_2 event_type_2 snap_CA snap_TX snap_WI day \\\n", "0 NaN NaN NaN 0 0 0 18 \n", "1 Cultural NaN NaN 0 0 0 19 \n", "2 NaN NaN NaN 0 0 0 20 \n", "3 NaN NaN NaN 0 0 0 21 \n", "4 NaN NaN NaN 0 0 0 22 \n", "\n", " sell_price \n", "0 3.97 \n", "1 3.97 \n", "2 3.97 \n", "3 3.97 \n", "4 3.97 \n", "\n", "[5 rows x 53 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n_lags = 30\n", "\n", "# sort data before computing lags\n", "data_index_vars = [\"item_id\", \"dept_id\", \"cat_id\", \"store_id\", \"state_id\"]\n", "data.sort_values(data_index_vars + [\"day_number\"], inplace=True)\n", "\n", "# shift within each item so that lags never cross from one item to another\n", "for lag in range(1, n_lags + 1):\n", "    data[f\"sales_lag_{lag}\"] = data.groupby(by=data_index_vars)[\"sales\"].shift(lag)\n", "\n", "# join the calendar on `d`, then the prices on `store_id`, `item_id`, and `wm_yr_wk`\n", "data = data.merge(calendar_df).merge(sell_prices_df)\n", "\n", "data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Treeffuser can handle **categorical columns**, but the dtype of those columns must be set to `category` in the DataFrame." 
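] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a minimal sketch of what this conversion does, here is a hypothetical toy column (not taken from the M5 data): pandas stores a `category` column as integer codes plus a table of categories, and a missing value, like the empty entries of `event_name_1`, gets the code `-1`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "# hypothetical toy column with a missing entry, similar to event_name_1\n", "events = pd.Series([\"SuperBowl\", None, \"Easter\", \"SuperBowl\"]).astype(\"category\")\n", "\n", "categories = events.cat.categories.tolist()  # sorted unique values\n", "codes = events.cat.codes.tolist()  # one integer code per row, -1 for missing\n", "print(categories)  # ['Easter', 'SuperBowl']\n", "print(codes)  # [1, -1, 0, 1]" 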
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "categorical_columns = [\n", "    \"item_id\",\n", "    \"dept_id\",\n", "    \"cat_id\",\n", "    \"store_id\",\n", "    \"state_id\",\n", "    \"d\",\n", "    \"wm_yr_wk\",\n", "    \"weekday\",\n", "    \"event_name_1\",\n", "    \"event_type_1\",\n", "    \"event_name_2\",\n", "    \"event_type_2\",\n", "    \"snap_CA\",\n", "    \"snap_TX\",\n", "    \"snap_WI\",\n", "]\n", "data[categorical_columns] = data[categorical_columns].astype(\"category\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we split the data by time: days 1 to 300 form the training set, and the remaining 65 days (301 to 365) form the test set used for evaluation." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(15216, 50)\n", "(3891, 50)\n" ] } ], "source": [ "is_train = data[\"day_number\"] <= 300\n", "\n", "y_name = \"sales\"\n", "x_names = [\n", "    name for name in data.columns if name != y_name and name not in [\"day_number\", \"date\"]\n", "]\n", "\n", "X_train, y_train, dates_train = (\n", "    data[is_train][x_names],\n", "    data[is_train][y_name],\n", "    data[is_train][\"date\"],\n", ")\n", "X_test, y_test, dates_test = (\n", "    data[~is_train][x_names],\n", "    data[~is_train][y_name],\n", "    data[~is_train][\"date\"],\n", ")\n", "\n", "print(X_train.shape)\n", "print(X_test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Probabilistic predictions with Treeffuser\n", "\n", "We regress the sales on the following covariates." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "item_id, dept_id, cat_id, store_id, state_id, d, sales_lag_1, sales_lag_2, sales_lag_3, sales_lag_4, sales_lag_5, sales_lag_6, sales_lag_7, sales_lag_8, sales_lag_9, sales_lag_10, sales_lag_11, sales_lag_12, sales_lag_13, sales_lag_14, sales_lag_15, sales_lag_16, sales_lag_17, sales_lag_18, sales_lag_19, sales_lag_20, sales_lag_21, sales_lag_22, sales_lag_23, sales_lag_24, sales_lag_25, sales_lag_26, sales_lag_27, sales_lag_28, sales_lag_29, sales_lag_30, wm_yr_wk, weekday, wday, month, year, event_name_1, event_type_1, event_name_2, event_type_2, snap_CA, snap_TX, snap_WI, day, sell_price\n" ] } ], "source": [ "print(\", \".join(map(str, X_train.columns)))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[LightGBM] [Warning] Met negative value in categorical features, will convert it to NaN\n", "[LightGBM] [Warning] Met negative value in categorical features, will convert it to NaN\n", "[LightGBM] [Warning] Met negative value in categorical features, will convert it to NaN\n", "[LightGBM] [Warning] Met negative value in categorical features, will convert it to NaN\n", "[LightGBM] [Warning] Categorical features with more bins than the configured maximum bin number found.\n", "[LightGBM] [Warning] For categorical features, max_bin and max_bin_by_feature may be ignored with a large number of categories.\n" ] }, { "data": { "text/html": [ "
Treeffuser(extra_lightgbm_params={}, seed=0)
" ], "text/plain": [ "Treeffuser(extra_lightgbm_params={}, seed=0)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = Treeffuser(seed=0)\n", "model.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [05:55<00:00, 3.55s/it]\n" ] } ], "source": [ "y_test_samples = model.sample(X_test, n_samples=100, seed=0, n_steps=50, verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Newsvendor model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We illustrate the practical relevance of accurate probabilistic predictions with an application to inventory management, using the newsvendor model (Arrow, Harris, and Marschak, 1951).\n", "\n", "Assume that every day we decide how many units $q$ of an item to buy. \n", "We buy each unit at a cost $c$ and sell it at a price $p$. \n", "The demand $y$, however, is random, which makes the decision uncertain. \n", "The goal is to maximize the expected profit:\n", "$$\\max_{q} p~\\mathbb{E}\\left[\\min(q, y)\\right] - c q.$$\n", "Increasing $q$ by one unit raises the expected revenue by $p~\\mathbb{P}(y > q)$ and the cost by $c$, so it pays to keep increasing $q$ as long as $F(q) < \\frac{p-c}{p}$, where $F$ is the CDF of $y$. The optimal solution to the newsvendor problem is therefore to buy $q = F^{-1}\\left( \\frac{p-c}{p} \\right)$ units, where $F^{-1}$ is the quantile function of the distribution of $y$. \n", "\n", "With Treeffuser, we can estimate these quantiles from the samples and thus forecast the optimal number of units to buy.\n", "\n", "To compute profits, we use the observed prices, assume a $50\\%$ markup over the stocking cost for all products (so $c = p / 1.5$ and the target quantile is $\\frac{p-c}{p} = \\frac{1}{3}$), and treat the observed sales of an item as its demand. Treeffuser learns the conditional distribution of the demand of each item, from which we estimate its quantiles and thus determine the optimal quantity to buy. \n", "\n", "Finally, we use the held-out data to compute the profit that would have been made had Treeffuser been used to forecast the demand of each item and manage its inventory." 
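] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a sanity check of the quantile rule, here is a minimal sketch on hypothetical data: Poisson draws stand in for Treeffuser's demand samples of a single item, the price is $p = 3$, and the $50\\%$ markup gives a cost of $c = p / 1.5 = 2$, so the target quantile is $(p-c)/p = 1/3$. Maximizing the Monte Carlo estimate of the expected profit over integer order quantities should recover the smallest sample value whose empirical CDF reaches $1/3$. The brute-force search is only practical here because the toy demand takes few values; the quantile rule avoids it altogether." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "rng = np.random.default_rng(0)\n", "# hypothetical demand samples for one item, standing in for Treeffuser samples\n", "y = rng.poisson(lam=2.0, size=1000)\n", "\n", "p = 3.0  # selling price\n", "c = p / 1.5  # stocking cost under the 50% markup assumption, here 2.0\n", "ratio = (p - c) / p  # critical ratio, here 1/3\n", "\n", "# quantile rule: smallest sample value whose empirical CDF reaches the ratio\n", "y_sorted = np.sort(y)\n", "q_rule = int(y_sorted[int(np.ceil(len(y) * ratio)) - 1])\n", "\n", "# brute force: maximize the Monte Carlo estimate of the expected profit over integer q\n", "candidates = np.arange(0, y.max() + 1)\n", "expected_profit = np.array([p * np.minimum(y, q).mean() - c * q for q in candidates])\n", "q_brute = int(candidates[np.argmax(expected_profit)])\n", "\n", "print(f\"critical ratio = {ratio:.3f}, quantile rule q = {q_rule}, brute force q = {q_brute}\")" 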
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def newsvendor_utility(y_true, quantity_ordered, prices, stocking_cost):\n", "    \"\"\"\n", "    The newsvendor utility function with stock q, demand y, selling price p, and stocking cost c is\n", "    $$ U(y, q, p, c) = p * min(y, q) - c * q $$\n", "    \"\"\"\n", "    utility = prices * np.minimum(y_true, quantity_ordered) - stocking_cost * quantity_ordered\n", "    return utility\n", "\n", "\n", "def newsvendor_optimal_quantity(y_samples, prices, stocking_cost):\n", "    \"\"\"\n", "    Returns the optimal quantity to order for the newsvendor problem.\n", "\n", "    It is given theoretically by:\n", "    $$ q* = argmax_{q} E[U(y, q, p, c)] $$\n", "    which has the closed-form solution\n", "    $$ q* = F^{-1}( (p - c) / p) $$\n", "    where F is the CDF of the demand distribution.\n", "    \"\"\"\n", "    # compute the target quantiles (p - c) / p\n", "    target_quantiles = (prices - stocking_cost) / prices\n", "    target_quantiles = np.maximum(target_quantiles, 0.0)\n", "\n", "    # compute the empirical quantile of the samples of each test point at its target level\n", "    res = []\n", "    for i in range(y_samples.shape[1]):\n", "        optimal_quantity = np.quantile(y_samples[:, i], target_quantiles[i])\n", "        res.append(optimal_quantity)\n", "    optimal_quantities = np.array(res)\n", "    return optimal_quantities" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "185.07666666666668" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the actual margin of each item is unknown, so we assume a 50% markup over the stocking cost\n", "profit_margin = 0.5\n", "\n", "prices = X_test[\"sell_price\"].values\n", "stocking_cost = prices / (1 + profit_margin)\n", "\n", "# compute optimal quantities\n", "optimal_quantities = newsvendor_optimal_quantity(y_test_samples, prices, stocking_cost)\n", "\n", "# Treeffuser models continuous responses, so we cast the predicted quantities to int\n", "# (astype(int) truncates toward zero)\n", "optimal_quantities = optimal_quantities.astype(int)\n", "\n", "profit = newsvendor_utility(y_test, optimal_quantities, prices, stocking_cost)\n", "profit.sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The plot below shows the cumulative profit over the test period, together with the daily price-weighted averages of demand and inventory." ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " date profit avg_inventory_weighted avg_demand_weighted\n", "0 2011-11-25 -2.213333 0.039476 0.784104\n", "1 2011-11-26 1.600000 0.033068 0.933309\n", "2 2011-11-27 0.053333 0.046308 0.796895\n", "3 2011-11-28 5.873333 0.067636 0.789166\n", "4 2011-11-29 1.580000 0.041286 0.913090\n", ".. ... ... ... ...\n", "60 2012-01-24 6.336667 0.138240 0.661504\n", "61 2012-01-25 7.876667 0.078847 0.530883\n", "62 2012-01-26 -3.493333 0.051774 0.894360\n", "63 2012-01-27 -0.710000 0.102579 1.093555\n", "64 2012-01-28 11.890000 0.116262 1.258714\n", "\n", "[65 rows x 4 columns]\n" ] } ], "source": [ "performance_data = pd.DataFrame(\n", "    {\n", "        \"date\": dates_test,\n", "        \"profit\": profit,\n", "        \"demand\": y_test,\n", "        \"inventory\": optimal_quantities,\n", "        \"price\": prices,\n", "    }\n", ")\n", "\n", "# for each day, compute average demand and inventory weighted by price\n", "daily_summary = (\n", "    performance_data.groupby(\"date\")\n", "    .agg(\n", "        profit=(\"profit\", \"sum\"),\n", "        avg_inventory_weighted=(\n", "            \"inventory\",\n", "            lambda x: (x * performance_data.loc[x.index, \"price\"]).sum()\n", "            / performance_data.loc[x.index, \"price\"].sum(),\n", "        ),\n", "        avg_demand_weighted=(\n", "            \"demand\",\n", "            lambda x: (x * performance_data.loc[x.index, \"price\"]).sum()\n", "            / performance_data.loc[x.index, \"price\"].sum(),\n", "        ),\n", "    )\n", "    .reset_index()\n", ")\n", "\n", "print(daily_summary)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { 
"image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# dictionary to store color, alpha, linewidth, and linestyle for each line\n", "line_styles = {\n", "    \"cumulative_profit\": {\"color\": \"teal\", \"alpha\": 1, \"linewidth\": 2.5, \"linestyle\": \"-\"},\n", "    \"inventory\": {\n", "        \"color\": \"blue\",\n", "        \"alpha\": 0.5,\n", "        \"linewidth\": 1.5,\n", "        \"linestyle\": \"--\",\n", "    },\n", "    \"demand\": {\n", "        \"color\": \"orange\",\n", "        \"alpha\": 0.5,\n", "        \"linewidth\": 1.5,\n", "        \"linestyle\": \"--\",\n", "    },\n", "}\n", "\n", "# create figure\n", "fig, ax1 = plt.subplots(figsize=(10, 6))\n", "\n", "# define x-axis\n", "dates = pd.to_datetime(daily_summary[\"date\"])\n", "\n", "# plot cumulative profit\n", "ax1.plot(\n", "    dates,\n", "    daily_summary[\"profit\"].cumsum(),\n", "    **line_styles[\"cumulative_profit\"],\n", "    label=\"Cumulative Profit\",\n", ")\n", "ax1.set_xlabel(\"Date\", fontsize=12)\n", "ax1.set_ylabel(\"Cumulative Profit ($)\", fontsize=12)\n", "ax1.grid(True, linestyle=\"--\", alpha=0.6)\n", "\n", "# create second y-axis for price-weighted inventory and demand\n", "ax2 = ax1.twinx()\n", "ax2.plot(\n", "    dates,\n", "    daily_summary[\"avg_inventory_weighted\"],\n", "    **line_styles[\"inventory\"],\n", "    label=\"Avg Inventory (Price Weighted)\",\n", ")\n", "ax2.plot(\n", "    dates,\n", "    daily_summary[\"avg_demand_weighted\"],\n", "    **line_styles[\"demand\"],\n", "    label=\"Avg Demand (Price Weighted)\",\n", ")\n", "ax2.set_ylabel(\"Avg Inventory and Demand (Units)\", fontsize=12)\n", "\n", "# combine all legends into one\n", "lines1, labels1 = ax1.get_legend_handles_labels()\n", "lines2, labels2 = ax2.get_legend_handles_labels()\n", "ax1.legend(\n", "    lines1 + lines2, labels1 + labels2, loc=\"upper center\", bbox_to_anchor=(0.5, 1.12), ncol=3\n", ")\n", "\n", "fig.autofmt_xdate()  # rotate x-tick labels\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" 
}, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 4 }