Temporal Fusion Transformers

The Temporal Fusion Transformer (TFT) is a model for multi-horizon forecasting that accepts a variety of multivariate time series inputs. It is designed to provide interpretable insights alongside high-accuracy forecasts.

The TFT ingests several different types of input data:

  1. Past target values $y$: Observations from the past within a defined window $k$.
  2. Time-varying inputs: These include inputs that are known ahead of time ($x$) and those that are not ($z$).
  3. Static covariates $s$: Metadata that does not change over time and provides context for the time series.

TFT produces forecasts by outputting quantiles, offering a range of possible future values rather than a single point prediction. The quantile forecast for the $i$-th time series at time $t$ for horizon $\tau$ is given by:

$$\hat{y}_i(q, t, \tau) = f_{q} \left( \tau, y_{i, t-k:t}, \mathbf{z}_{i, t-k:t}, \mathbf{x}_{i, t-k:t+\tau}, \mathbf{s}_i \right)$$

Here, $f_{q}$ is the function that predicts the $q$-th quantile, and the other terms represent the historical target values, unknown inputs, known inputs up to $t+\tau$, and static covariates respectively.
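To make the index ranges in the formula concrete, here is a minimal sketch of how the inputs for one series could be sliced at forecast time $t$. The function name `slice_tft_inputs`, the toy data, and the choice to treat both ends of each range as inclusive are assumptions made for illustration.

```python
def slice_tft_inputs(y, z, x, s, t, k, tau):
    """Collect the inputs of f_q for one series at forecast time t.

    y: past targets, z: unknown (observed-only) inputs, x: known inputs,
    s: static covariates. Assumes index ranges are inclusive at both ends.
    """
    if t - k < 0 or t + tau >= len(x):
        raise ValueError("window falls outside the available data")
    return {
        "past_targets": y[t - k : t + 1],      # y_{t-k:t}
        "past_unknown": z[t - k : t + 1],      # z_{t-k:t}
        "known": x[t - k : t + tau + 1],       # x_{t-k:t+tau}, extends into the future
        "static": s,                           # s, no time index
    }


# Toy series with 10 time steps (all values are made up).
T = 10
y = [float(v) for v in range(T)]
z = [[v * 0.1] for v in range(T)]
x = [[v % 7] for v in range(T)]        # e.g. a day-of-week feature known in advance
s = {"store_id": 3}

window = slice_tft_inputs(y, z, x, s, t=5, k=3, tau=2)
print(len(window["past_targets"]), len(window["known"]))
```

Only the known inputs extend past $t$, which is what lets the model condition its forecast on information such as calendar features.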

TFT Model Overview

An overview of the TFT model architecture from the paper:

TFT architecture

Gated Residual Networks (GRN)

The TFT employs GRNs at various layers to adaptively regulate the depth and complexity of the neural network, enhancing its generalization capabilities. GRNs provide this flexibility through skip (residual) connections, which feed the output of a layer to later layers in the network that are not directly adjacent.

This way, the model can learn that some non-linear processing layers are unnecessary and skip them. GRNs improve the generalization capabilities of the model across different application scenarios (e.g. noisy or small datasets) and help to significantly reduce the number of needed parameters and operations.

Gated Residual Network

Here, ELU stands for the Exponential Linear Unit activation function.
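The gating and skip behaviour can be sketched in a few lines of NumPy. This is a minimal forward pass under stated assumptions: the layer ordering (dense, ELU, dense, gated linear unit, residual add, layer normalization) follows the usual description of the GRN in the TFT paper, while the class name, weight names, random initialization, and the omission of dropout and of a learned layer-norm scale are simplifications made here.

```python
import numpy as np


def elu(x):
    # Exponential Linear Unit: identity for x > 0, exp(x) - 1 otherwise.
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0.0)))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


class GatedResidualNetwork:
    """Forward-only GRN sketch (hypothetical weight names, no training)."""

    def __init__(self, d_model, d_context=None, seed=0):
        rng = np.random.default_rng(seed)
        init = lambda *shape: rng.normal(0.0, 0.1, size=shape)
        self.W2, self.b2 = init(d_model, d_model), np.zeros(d_model)
        self.W3 = init(d_context, d_model) if d_context else None
        self.W1, self.b1 = init(d_model, d_model), np.zeros(d_model)
        self.W4, self.b4 = init(d_model, d_model), np.zeros(d_model)
        self.W5, self.b5 = init(d_model, d_model), np.zeros(d_model)

    def __call__(self, a, c=None):
        hidden = a @ self.W2 + self.b2
        if c is not None and self.W3 is not None:
            hidden = hidden + c @ self.W3          # optional context vector
        eta2 = elu(hidden)
        eta1 = eta2 @ self.W1 + self.b1
        gate = sigmoid(eta1 @ self.W4 + self.b4)   # values in (0, 1)
        gated = gate * (eta1 @ self.W5 + self.b5)  # gated linear unit
        return layer_norm(a + gated)               # residual (skip) connection


rng = np.random.default_rng(1)
a = rng.normal(size=(5, 8))      # 5 time steps, 8 features
c = rng.normal(size=(4,))        # one static context vector
grn = GatedResidualNetwork(d_model=8, d_context=4)
out = grn(a, c)

# A strongly negative gate bias closes the gate, so only the skip path remains.
grn.b4 = np.full(8, -50.0)
skipped = grn(a, c)
```

With the gate closed, the output reduces to the normalized input, which is the sense in which the network can learn to skip the non-linear processing.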

Static Covariate Encoders

These encoders transform static metadata into context vectors that influence the model at different stages, ensuring that static information conditions the learning of temporal patterns. The stages at which these context vectors are injected are:

  1. Temporal variable selection
  2. Local processing of temporal representations in the Sequence-to-Sequence layer
  3. Static enrichment of temporal representations

Variable Selection Networks

A separate variable selection block is implemented for each type of input: static covariates, past inputs (both known and unknown time-dependent inputs), and known future inputs. These blocks learn to weigh the importance of each input feature. The subsequent Sequence-to-Sequence layer then takes as input the re-weighted sums of transformed inputs for each time step. Here, transformed inputs refer to learned linear transformations of continuous features and entity embeddings of categorical ones.

The external context vector is the output of the static covariate encoder block. It is therefore omitted for the variable selection block of static covariates.
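The re-weighting step can be illustrated with a short NumPy sketch. It is a deliberate simplification: a single linear map stands in for the GRN that produces the selection weights, and the per-feature non-linear processing is left out. The function name, weight names, and shapes are assumptions for illustration.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def variable_selection(xi, W_flat, W_ctx=None, context=None):
    """Re-weight transformed inputs per time step.

    xi:      (T, n_vars, d_model) transformed inputs (embeddings or linear maps)
    W_flat:  (n_vars * d_model, n_vars) stand-in for the weight-producing GRN
    W_ctx:   (d_context, n_vars) projection of the static context (optional)
    context: (d_context,) static context vector (optional)
    """
    T, n_vars, d_model = xi.shape
    logits = xi.reshape(T, n_vars * d_model) @ W_flat
    if context is not None:
        logits = logits + context @ W_ctx
    weights = softmax(logits, axis=-1)                 # (T, n_vars), rows sum to 1
    selected = (weights[..., None] * xi).sum(axis=1)   # (T, d_model)
    return selected, weights


rng = np.random.default_rng(0)
xi = rng.normal(size=(6, 3, 4))          # 6 time steps, 3 variables, d_model = 4
W_flat = rng.normal(size=(12, 3))
W_ctx = rng.normal(size=(5, 3))
context = rng.normal(size=(5,))
selected, weights = variable_selection(xi, W_flat, W_ctx, context)
```

Averaging `weights` over time steps and samples gives one importance score per variable, which is how these blocks support interpretation.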

Sequence-to-Sequence (Seq2Seq) Layers

In place of the positional encoding used in standard transformers, TFT uses Seq2Seq layers to handle local temporal patterns, which are essential for time series forecasting. These layers are better suited to time series data because their recurrent connections capture local temporal patterns.

Context vectors are used in this block to initialize the cell state and hidden state of the first LSTM unit. They are also employed in what the authors call the "static enrichment layer" to enrich the learned temporal representation from the Sequence-to-Sequence layer with static information.

Interpretable Multi-Head Attention

Classical attention mechanisms weigh the importance of values $V$ based on the relationships between keys $K$ and queries $Q$:

$$\text{Attention}(Q, K, V) = \alpha(Q,K) V$$

where $\alpha(Q,K)$ are the attention weights. A common choice for $\alpha$ is scaled dot-product attention.
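For reference, scaled dot-product attention is usually written as follows, with the softmax applied row-wise. The symbol $d_k$ (the dimension of the keys) is notation introduced here and is not used elsewhere in this article.

$$\alpha(Q, K) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right)$$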

The original multi-head attention mechanism uses multiple attention heads to re-weigh the values based on the relevance between keys and queries. The outputs of the different heads are then combined via concatenation, as follows:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) W^O$$

where each individual head $\text{head}_i$ is defined as:

$$\text{head}_i = \text{Attention}(QW^{Q}_i, KW^{K}_i, VW^{V}_i)$$

In self-attention, queries, keys and values come from the same input. This allows the model to learn the relevance of each time step with respect to the rest of the input sequence, and therefore to capture long-range temporal dependencies. Note that in the decoder, subsequent time steps are masked at each decoding step to avoid information leakage from future to past data points.

TFT uses an interpretable form of multi-head attention that allows tracing back the importance of features across different time steps. Instead of each head having its own value weights, a single set of value weights is shared across all attention heads. This makes it easy to trace back the most relevant values. The outputs of all heads are then averaged:

$$\text{InterpretableMultiHead}(Q, K, V) = \frac{1}{h} \sum_{i=1}^h \text{head}_i W^O$$

Here each $\text{head}_i$ is computed with the shared value weights $W^V$ in place of $W^V_i$.

Interpretable multi-head attention
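A small NumPy sketch can make the effect of sharing the value weights visible. It assumes scaled dot-product attention for $\alpha$ and a causal mask for decoder self-attention; the function name, argument names, and shapes are illustrative choices, not taken from any particular implementation.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def interpretable_multihead(Q, K, V, W_q, W_k, W_v, W_o, causal=True):
    """Multi-head attention with value weights shared across heads.

    W_q, W_k: (h, d_model, d_head) head-specific query/key projections
    W_v:      (d_model, d_head)    single value projection shared by all heads
    W_o:      (d_head, d_model)    output projection
    Returns the output and the head-averaged attention weights.
    """
    h, _, d_head = W_q.shape
    values = V @ W_v
    # True above the diagonal: positions a query is not allowed to look at.
    mask = np.triu(np.ones((Q.shape[0], K.shape[0]), dtype=bool), k=1)
    head_outputs, head_weights = [], []
    for i in range(h):
        scores = (Q @ W_q[i]) @ (K @ W_k[i]).T / np.sqrt(d_head)
        if causal:
            scores = np.where(mask, -np.inf, scores)
        alpha = softmax(scores, axis=-1)
        head_weights.append(alpha)
        head_outputs.append(alpha @ values)
    output = np.mean(head_outputs, axis=0) @ W_o   # (1/h) * sum_i head_i W^O
    mean_alpha = np.mean(head_weights, axis=0)     # one attention map to inspect
    return output, mean_alpha


rng = np.random.default_rng(0)
T, d_model, d_head, h = 5, 8, 4, 2
X = rng.normal(size=(T, d_model))                  # self-attention: Q = K = V = X
W_q = rng.normal(size=(h, d_model, d_head))
W_k = rng.normal(size=(h, d_model, d_head))
W_v = rng.normal(size=(d_model, d_head))
W_o = rng.normal(size=(d_head, d_model))
output, mean_alpha = interpretable_multihead(X, X, X, W_q, W_k, W_v, W_o)
```

Because every head multiplies the same projected values, averaging the head outputs is the same as applying the averaged attention weights once, so `mean_alpha` alone describes which time steps the layer attended to.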

Quantile Regression

TFT uses quantile regression to generate forecasts $\hat{y}$ of the target, which gives a range of possible outcomes rather than a single point estimate. It is trained with the pinball loss function:

$$QL(y, \hat{y}, q) = q \max(0, y - \hat{y}) + (1 - q) \max(0, \hat{y} - y)$$

Intuitively, the first term of this loss function is activated for under-estimations and is highly weighted for upper quantiles, whereas the second term is activated for over-estimations and is highly weighted for lower quantiles. This way, the optimization process pushes the model towards reasonable over-estimations for upper quantiles and under-estimations for lower quantiles. Notice that for the median prediction ($0.5$ quantile), the quantile loss equals half the absolute error, so minimizing it is equivalent to minimizing the MAE loss.

The TFT is then trained by minimizing an aggregate of the quantile losses across all quantile outputs. Quantile regression is particularly useful in high-stakes applications, where it quantifies the uncertainty of the predicted values at each time step.
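The asymmetry of the pinball loss is easiest to see with numbers. The sketch below implements the formula directly; the aggregation used in `total_quantile_loss` (a plain average over all quantiles and observations) is an assumption made here, and the function names and toy values are illustrative.

```python
def pinball_loss(y, y_hat, q):
    """Quantile (pinball) loss for a single observation."""
    return q * max(0.0, y - y_hat) + (1.0 - q) * max(0.0, y_hat - y)


def total_quantile_loss(y_true, preds_by_quantile):
    """Average pinball loss over all quantiles and observations.

    preds_by_quantile maps each quantile q to a list of predictions
    aligned with y_true. The plain average is an illustrative choice.
    """
    total, count = 0.0, 0
    for q, preds in preds_by_quantile.items():
        for y, y_hat in zip(y_true, preds):
            total += pinball_loss(y, y_hat, q)
            count += 1
    return total / count


# For the 0.9 quantile, under-estimating by 2 costs far more than over-estimating by 2.
under = pinball_loss(10.0, 8.0, q=0.9)    # 0.9 * 2, about 1.8
over = pinball_loss(10.0, 12.0, q=0.9)    # 0.1 * 2, about 0.2

# For the median, the loss is half the absolute error in either direction.
median = pinball_loss(10.0, 12.0, q=0.5)  # 0.5 * 2 = 1.0

y_true = [10.0, 20.0]
preds = {0.1: [8.0, 17.0], 0.5: [10.5, 19.0], 0.9: [13.0, 24.0]}
loss = total_quantile_loss(y_true, preds)
```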

Interpretability Features

TFT's design allows for unique interpretability features:

  • Variable Selection Blocks: Determine the global importance of input features.
  • Interpretable Attention Weights: Offer insights into the relevance of past timesteps for forecasting.
    • Done by adjusting the standard multi-head attention definition to have shared weights for values across all attention heads
    • This insight is traditionally gained using preliminary seasonality and autocorrelation analysis
    • Identify significant changes in temporal patterns. This is done by computing an average attention pattern per forecast horizon and evaluating the distance between it and the attention weights at each point.
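The change-detection idea in the last bullet can be sketched as follows. The text above does not fix a distance metric, so the one used here (a distance based on the Bhattacharyya coefficient between two distributions) is an assumption, as are the function name and the toy attention matrix.

```python
import numpy as np


def attention_shift_distance(attn):
    """Distance of each attention pattern from the average pattern.

    attn: (n_forecast_dates, n_positions); each row holds the attention
    weights for one forecast date at a fixed horizon and sums to 1.
    """
    mean_pattern = attn.mean(axis=0)
    # Bhattacharyya coefficient: 1 for identical distributions, lower otherwise.
    coefficient = np.sqrt(attn * mean_pattern).sum(axis=1)
    return np.sqrt(np.clip(1.0 - coefficient, 0.0, None))


# Nine forecast dates share a uniform pattern; one concentrates on a single step.
attn = np.full((10, 4), 0.25)
attn[6] = [0.97, 0.01, 0.01, 0.01]
distances = attention_shift_distance(attn)
most_unusual = int(np.argmax(distances))   # index of the date with the largest shift
```

Forecast dates with a large distance are candidates for regime changes, since the model is attending to the past differently than it usually does.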

Implementations and Resources

TFT's power and interpretability make it applicable across a range of domains, with implementations available in PyTorch Forecasting and Darts.

Conclusion

TFT is a sophisticated model for time series data that delivers on multiple fronts:

  • Multi-horizon forecasting
  • Multivariate time series with heterogeneous features (support for static covariates and for time-varying known and unknown variables)
  • Prediction intervals to quantify uncertainty
  • Interpretability of results

TFT is mainly able to:

  1. Capture temporal dependencies at different time scales by a combination of the LSTM Sequence-to-Sequence and the Transformer's Self-Attention mechanism
  2. Enrich learned temporal representations with static information about measured entities.

Resources: