Bridging known and unknown dynamics by transformer-based machine-learning inference from sparse observations



Hybrid machine learning

Consider a nonlinear dynamical system described by

$$\frac{d{{{\bf{x}}}}}{dt}={{{\bf{F}}}}({{{\bf{x}}}}),$$

(3)

where \({{{\bf{x}}}}\in {{\mathbb{R}}}^{D}\) is the D-dimensional state vector and F( ) is the unknown vector field governing the dynamics. Sampling the trajectory at Ls time steps yields the state matrix \({{{\bf{X}}}}\in {{\mathbb{R}}}^{{L}_{s}\times D}\), and the sparse, noisy measurements are modeled as

$$\tilde{{{{\bf{X}}}}}={g}_{\alpha }({{{\bf{X}}}}+\sigma \Xi ),$$

(4)

where \(\tilde{{{{\bf{X}}}}}\in {{\mathbb{R}}}^{{L}_{s}\times D}\) is the observational data matrix of dimension Ls × D and gα( ) is the following element-wise observation function:

$${X}_{ij}^{O}={g}_{\alpha }({X}_{ij})=\left\{\begin{array}{ll}{X}_{ij},\quad &\,{{\mbox{if}}}\,\,{X}_{ij}\,{{{\rm{is}}}}\,{{{\rm{observed}}}},\\ 0,\quad &\,{\mbox{otherwise}}\,,\hfill\end{array}\right.$$

(5)

with α representing the probability that matrix element Xij is observed. In Eq. (4), Gaussian white noise of amplitude σ is present during the measurement process, where \(\Xi \sim {{{\mathcal{N}}}}(0,1)\). Our goal is to use machine learning to approximate the system dynamics function F( ) by another function \({{{{\bf{F}}}}}^{{\prime} }(\cdot )\), assuming that F is Lipschitz continuous with respect to x and that the observation function \({{{\bf{g}}}}:{{{\bf{X}}}}\to \tilde{{{{\bf{X}}}}}\) produces sparse data. To achieve this, it is necessary to design a function \({{{\mathcal{F}}}}(\tilde{{{{\bf{X}}}}})={{{\bf{X}}}}\) that implicitly contains \({{{{\bf{F}}}}}^{{\prime} }(\cdot )\approx {{{\bf{F}}}}(\cdot )\), so that it reconstructs the system dynamics by filling the gaps in the observations, where \({{{\mathcal{F}}}}(\tilde{{{{\bf{X}}}}})\) should be capable of adapting to any given unknown dynamics.

Selecting an appropriate neural network architecture for reconstructing dynamics from sparse data requires meeting two fundamental requirements: (1) dynamical memory to capture long-range dependencies in the sparse data, and (2) flexibility to handle input sequences of varying lengths. Transformers32, originally developed for natural language processing, satisfy these requirements through their attention-based structure. In particular, transformers have been widely applied and proven effective for time series analysis, including prediction51,52,53, anomaly detection54, and classification55. Figure 6 illustrates the transformer’s main structure. The data matrix \(\tilde{{{{\bf{X}}}}}\) is first processed through a linear fully-connected layer with bias, transforming it into an Ls × N matrix. This output is then combined with a positional encoding matrix, which embeds temporal-ordering information into the time series data. This projection process can be described as56:

$${{{{\bf{X}}}}}_{p}=\tilde{{{{\bf{X}}}}}{{{{\bf{W}}}}}_{p}+{{{{\bf{W}}}}}_{b}+{{{\bf{PE}}}},$$

(6)

where \({{{{\bf{W}}}}}_{p}\in {{\mathbb{R}}}^{D\times N}\) represents the fully-connected layer with the bias matrix \({{{{\bf{W}}}}}_{b}\in {{\mathbb{R}}}^{{L}_{s}\times N}\), and \({{{\bf{PE}}}}\in {{\mathbb{R}}}^{{L}_{s}\times N}\) is the positional encoding matrix. Since the transformer model does not inherently capture the order of the input sequence, positional encoding is necessary to provide information about the position of each time step. For a given position \(1\le {{{\rm{pos}}}}\le {L}_{s}^{max}\) and dimension index d (such that the column indices 2d and 2d + 1 run over the N embedding dimensions), the encoding is given by

$${{{{\bf{PE}}}}}_{pos,2d}=\sin \left(\frac{{{{\rm{pos}}}}}{{10000}^{2d/N}}\right),$$

(7)

$${{{{\bf{PE}}}}}_{pos,2d+1}=\cos \left(\frac{{{{\rm{pos}}}}}{{10000}^{2d/N}}\right).$$

(8)
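
As an illustration of Eqs. (6)-(8), a minimal PyTorch-style sketch of the input projection and sinusoidal positional encoding could look as follows; the sizes D, Ls, and N and the tensor names are placeholder values for illustration, not the hyperparameters used in this work.

```python
import math
import torch

def positional_encoding(L_s: int, N: int) -> torch.Tensor:
    """Sinusoidal positional encoding PE of shape (L_s, N), Eqs. (7)-(8)."""
    pe = torch.zeros(L_s, N)
    pos = torch.arange(L_s, dtype=torch.float32).unsqueeze(1)        # (L_s, 1)
    div = torch.exp(torch.arange(0, N, 2, dtype=torch.float32)
                    * (-math.log(10000.0) / N))                      # 1 / 10000^{2d/N}
    pe[:, 0::2] = torch.sin(pos * div)   # even columns, Eq. (7)
    pe[:, 1::2] = torch.cos(pos * div)   # odd columns, Eq. (8)
    return pe

# Illustrative sizes: D observed variables, sequence length L_s, embedding width N.
D, L_s, N = 3, 200, 128
X_tilde = torch.rand(L_s, D)               # sparse, noisy observation matrix
proj = torch.nn.Linear(D, N, bias=True)    # W_p and W_b of Eq. (6)
X_p = proj(X_tilde) + positional_encoding(L_s, N)   # projected input of Eq. (6)
```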

The projected matrix \({{{{\bf{X}}}}}_{p}\in {{\mathbb{R}}}^{{L}_{s}\times N}\) then serves as the input sequence for Nb attention blocks. Each block contains a multi-head attention layer, a residual connection with layer normalization (add & layer norm), a feed-forward layer, and a second residual connection with layer normalization. The core of the transformer lies in the self-attention mechanism, which allows the model to weight the significance of distinct time steps. The multi-head self-attention layer is composed of several independent attention heads. The first head (in the first block) has three learnable weight matrices that linearly map Xp into the query Q1 and key K1 of dimension Ls × dk and the value V1 of dimension Ls × dv:

$${{{{\bf{Q}}}}}_{1}={{{{\bf{X}}}}}_{p}{{{{\bf{W}}}}}_{{{{{\bf{Q}}}}}_{{{{\bf{1}}}}}},\quad {{{{\bf{K}}}}}_{1}={{{{\bf{X}}}}}_{p}{{{{\bf{W}}}}}_{{{{{\bf{K}}}}}_{{{{\bf{1}}}}}},\quad {{{{\bf{V}}}}}_{1}={{{{\bf{X}}}}}_{p}{{{{\bf{W}}}}}_{{{{{\bf{V}}}}}_{{{{\bf{1}}}}}},$$

(9)

where \({{{{\bf{W}}}}}_{{{{{\bf{Q}}}}}_{1}}\in {{\mathbb{R}}}^{N\times {d}_{k}}\), \({{{{\bf{W}}}}}_{{{{{\bf{K}}}}}_{1}}\in {{\mathbb{R}}}^{N\times {d}_{k}}\), and \({{{{\bf{W}}}}}_{{{{{\bf{V}}}}}_{1}}\in {{\mathbb{R}}}^{N\times {d}_{v}}\) are the trainable weight matrices, dk is the dimension of the queries and keys, and dv is the dimension of the values. A convenient choice is dk = dv = N. The attention scores between the query Q1 and the key K1 are calculated by a scaled multiplication, followed by a softmax function:

$${{{{\bf{A}}}}}_{{{{{\bf{Q}}}}}_{1},{{{{\bf{K}}}}}_{1}}={{{\rm{softmax}}}}\left(\frac{{{{{\bf{Q}}}}}_{1}{{{{\bf{K}}}}}_{1}^{{\mathsf{T}}}}{\sqrt{{d}_{k}}}\right),$$

(10)

where \({{{{\bf{A}}}}}_{{{{{\bf{Q}}}}}_{1},{{{{\bf{K}}}}}_{1}}\in {{\mathbb{R}}}^{{L}_{s}\times {L}_{s}}\). The softmax function normalizes each row with \({{{\rm{softmax}}}}({x}_{i})=\exp ({x}_{i})/{\sum }_{j}\exp ({x}_{j})\), and the \(\sqrt{{d}_{k}}\) factor counteracts the growth of the dot-product standard deviation with the dimension dk. For the first head (in the first block), the attention output is computed as the matrix product of \({{{{\bf{A}}}}}_{{{{{\bf{Q}}}}}_{1},{{{{\bf{K}}}}}_{1}}\) and V1:

$${{{{\bf{O}}}}}_{11}=\, {{{\rm{Attention}}}}({{{{\bf{Q}}}}}_{1},{{{{\bf{K}}}}}_{1},{{{{\bf{V}}}}}_{1}),\\=\, {{{{\bf{A}}}}}_{{{{{\bf{Q}}}}}_{1},{{{{\bf{K}}}}}_{1}}{{{{\bf{V}}}}}_{1}={{{\rm{softmax}}}}\left(\frac{{{{{\bf{Q}}}}}_{1}{{{{\bf{K}}}}}_{1}^{{\mathsf{T}}}}{\sqrt{{d}_{k}}}\right){{{{\bf{V}}}}}_{1},$$

(11)

where \({{{{\bf{O}}}}}_{11}\in {{\mathbb{R}}}^{{L}_{s}\times {d}_{v}}\). The transformer employs multiple (h) attention heads to capture information from different representation subspaces. The resulting head outputs O1i (i = 1, …, h) are concatenated and linearly mapped into a sequence \({{{{\bf{O}}}}}_{1}\in {{\mathbb{R}}}^{{L}_{s}\times N}\):

$${{{{\bf{O}}}}}_{1}={{{\mathcal{C}}}}({{{{\bf{O}}}}}_{11},{{{{\bf{O}}}}}_{12},\cdots \,,{{{{\bf{O}}}}}_{1h}){{{{\bf{W}}}}}_{o1},$$

(12)

where \({{{\mathcal{C}}}}\) is the concatenation operation, h is the number of heads, and \({{{{\bf{W}}}}}_{o1}\in {{\mathbb{R}}}^{h{d}_{v}\times N}\) is an additional linear output projection that enhances performance. The output of the attention layer undergoes a residual connection and layer normalization, producing XR1 as follows:

$${{{{\bf{X}}}}}_{R1}={{{\rm{LayerNorm}}}}({{{{\bf{X}}}}}_{{{{\bf{p}}}}}+{{{\rm{Dropout}}}}({{{{\bf{O}}}}}_{1}))$$

(13)

A feed-forward layer then processes this data matrix, generating output \({{{{\bf{X}}}}}_{F1}\in {{\mathbb{R}}}^{{L}_{s}\times N}\) as:

$${{{{\bf{X}}}}}_{F1}=\max \left(0,{{{{\bf{X}}}}}_{R1}{{{{\bf{W}}}}}_{{F}_{a}}+{{{{\bf{b}}}}}_{a}\right){{{{\bf{W}}}}}_{{F}_{b}}+{{{{\bf{b}}}}}_{b},$$

(14)

where \({{{{\bf{W}}}}}_{{F}_{a}}\in {{\mathbb{R}}}^{N\times {d}_{f}}\), \({{{{\bf{W}}}}}_{{F}_{b}}\in {{\mathbb{R}}}^{{d}_{f}\times N}\), ba and bb are biases, and \(\max (0,\cdot )\) denotes a ReLU activation function. This output is again subjected to a residual connection and layer normalization.
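
The operations of one attention block, Eqs. (9)-(14), can be sketched compactly as below. The sketch follows the equations directly rather than reproducing the optimized implementation; the head count h, the choice dk = dv = N, the dropout rate, and the feed-forward width df are illustrative placeholders.

```python
import torch
import torch.nn.functional as F_nn

L_s, N, h, d_f = 200, 128, 4, 512   # illustrative sizes
d_k = d_v = N                        # the convenient choice d_k = d_v = N

def attention_head(X_p, W_Q, W_K, W_V):
    """Eqs. (9)-(11): one scaled dot-product attention head."""
    Q, K, V = X_p @ W_Q, X_p @ W_K, X_p @ W_V           # (L_s, d_k), (L_s, d_k), (L_s, d_v)
    A = F_nn.softmax(Q @ K.T / d_k ** 0.5, dim=-1)      # (L_s, L_s) attention scores, Eq. (10)
    return A @ V                                         # (L_s, d_v) head output, Eq. (11)

# Randomly initialized parameters stand in for the trained weights.
heads_W = [[torch.randn(N, d_k) for _ in range(3)] for _ in range(h)]
W_o1 = torch.randn(h * d_v, N)                           # output projection of Eq. (12)
norm1, norm2 = torch.nn.LayerNorm(N), torch.nn.LayerNorm(N)
dropout = torch.nn.Dropout(p=0.1)                        # illustrative dropout rate
ffn = torch.nn.Sequential(                               # Eq. (14): ReLU feed-forward layer
    torch.nn.Linear(N, d_f), torch.nn.ReLU(), torch.nn.Linear(d_f, N))

X_p = torch.rand(L_s, N)                                 # projected input from Eq. (6)
O_1 = torch.cat([attention_head(X_p, *w) for w in heads_W], dim=-1) @ W_o1   # Eq. (12)
X_R1 = norm1(X_p + dropout(O_1))                         # Eq. (13): residual connection + layer norm
X_out = norm2(X_R1 + dropout(ffn(X_R1)))                 # second residual connection + layer norm
```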

Fig. 6: Transformer architecture.

The transformer receives the sparse and random observation as the input and generates the reconstructed output. Nb refers to the number of transformer blocks. See text for a detailed mathematical description.

The output of the first block operation is used as the input to the second block. The same procedure is repeated for each of the remaining Nb–1 blocks. The final output passes through a feed-forward layer to generate the prediction. Overall, the whole process can be represented as \({{{\bf{Y}}}}={{{\mathcal{F}}}}(\tilde{{{{\bf{X}}}}})\).
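
For orientation, the overall mapping \({{{\bf{Y}}}}={{{\mathcal{F}}}}(\tilde{{{{\bf{X}}}}})\) can be assembled from standard PyTorch modules. This is a minimal sketch under the assumptions that the built-in encoder layer is an acceptable stand-in for the blocks described above, that the output head is a single linear layer, and that the positional_encoding function from the earlier sketch is available; the layer sizes are placeholders, not the values in Table 1.

```python
import torch

class SparseReconstructionTransformer(torch.nn.Module):
    """Sketch: project, add positional encoding, apply Nb blocks, map back to D."""
    def __init__(self, D=3, N=128, n_heads=4, N_b=4, d_f=512, max_len=3000):
        super().__init__()
        self.input_proj = torch.nn.Linear(D, N)
        self.register_buffer("pe", positional_encoding(max_len, N))  # from the earlier sketch
        layer = torch.nn.TransformerEncoderLayer(d_model=N, nhead=n_heads,
                                                 dim_feedforward=d_f, batch_first=True)
        self.blocks = torch.nn.TransformerEncoder(layer, num_layers=N_b)
        self.output_proj = torch.nn.Linear(N, D)

    def forward(self, x_tilde):                      # x_tilde: (batch, L_s, D) sparse observations
        L_s = x_tilde.shape[1]
        x_p = self.input_proj(x_tilde) + self.pe[:L_s]
        return self.output_proj(self.blocks(x_p))    # reconstructed series of shape (batch, L_s, D)
```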

The second component of our hybrid machine-learning framework is reservoir computing, which takes the output of the transformer as its input to reconstruct the long-term climate, or attractor, of the target system. A detailed description of the reservoir computing used in this context and of its hyperparameter optimization is presented in Supplementary Notes 4 and 5.
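
The reservoir-computing details are given in Supplementary Notes 4 and 5; purely for orientation, a generic echo-state-network sketch of the kind commonly used for attractor reconstruction is shown below. The reservoir dimension, spectral radius, and leakage rate are illustrative placeholders, not the optimized values used in this work.

```python
import numpy as np

D, n_res = 3, 500                        # input dimension and reservoir size (placeholders)
rng = np.random.default_rng(0)
W_in = rng.uniform(-0.1, 0.1, (n_res, D))
W_res = rng.normal(0.0, 1.0, (n_res, n_res))
W_res *= 0.9 / np.max(np.abs(np.linalg.eigvals(W_res)))   # rescale the spectral radius

def run_reservoir(inputs, leak=0.3):
    """Drive the reservoir with the transformer-reconstructed time series (shape (T, D))."""
    r = np.zeros(n_res)
    states = []
    for u in inputs:
        r = (1 - leak) * r + leak * np.tanh(W_in @ u + W_res @ r)
        states.append(r.copy())
    return np.array(states)   # (T, n_res); an output layer is then fit, e.g., by ridge regression
```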

Machine learning loss

To evaluate the reliability of the generated output, we minimize a combined loss function with two components: (1) a mean squared error (MSE) loss that penalizes the deviation of the output from the ground truth, and (2) a smoothness loss that ensures the output maintains appropriate continuity. The loss function is given by

$${{{\mathcal{L}}}}={\alpha }_{1}{{{{\mathcal{L}}}}}_{{{{\rm{mse}}}}}+{\alpha }_{2}{{{{\mathcal{L}}}}}_{{{{\rm{smooth}}}}},$$

(15)

where α1 and α2 are scalar weights controlling the trade-off between the two loss terms. The first component \({{{{\mathcal{L}}}}}_{{{{\rm{mse}}}}}\) measures the squared error between the predictions and the ground truth:

$${{{{\mathcal{L}}}}}_{{{{\rm{mse}}}}}=\frac{1}{n} {\sum}_{i=1}^{n}{({y}_{i}-{\hat{y}}_{i})}^{2},$$

(16)

with n being the total number of data points, and yi and \({\hat{y}}_{i}\) denoting the ground-truth and predicted values at time point i, respectively. The second component \({{{{\mathcal{L}}}}}_{{{{\rm{smooth}}}}}\) of the loss function consists of two terms: Laplacian regularization and total-variation regularization, which penalize the second-order differences and the absolute differences, respectively, between consecutive predictions. The two terms are given by:

$${{{{\mathcal{L}}}}}_{{{{\rm{laplacian}}}}}=\frac{1}{n-2} {\sum}_{i=2}^{n-1}{({\hat{y}}_{i-1}+{\hat{y}}_{i+1}-2{\hat{y}}_{i})}^{2},$$

(17)

and

$${{{{\mathcal{L}}}}}_{{{{\rm{tv}}}}}=\frac{1}{n-1} {\sum}_{i=1}^{n-1}| {\hat{y}}_{i}-{\hat{y}}_{i+1}| .$$

(18)

We assign the same weights to the two penalties, so the final combined loss function to be minimized is

$${{{\mathcal{L}}}}={{{{\mathcal{L}}}}}_{{{{\rm{mse}}}}}+{\alpha }_{s}({{{{\mathcal{L}}}}}_{{{{\rm{laplacian}}}}}+{{{{\mathcal{L}}}}}_{{{{\rm{tv}}}}}).$$

(19)

We set αs = 0.1. It is worth noting that the smoothness penalty is a crucial hyperparameter that should be carefully selected: an excessively strong penalty leads the model to learn overly coarse-grained dynamics, while omitting the penalty altogether yields jagged, poorly smoothed reconstructed curves (Supplementary Note 5).
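
In code, the combined loss of Eq. (19) follows directly from Eqs. (16)-(18). The sketch below assumes one-dimensional prediction and target tensors of equal length.

```python
import torch

def combined_loss(y_hat, y, alpha_s=0.1):
    """Combined loss of Eq. (19): MSE plus Laplacian and total-variation smoothness terms."""
    mse = torch.mean((y - y_hat) ** 2)                                        # Eq. (16)
    laplacian = torch.mean((y_hat[:-2] + y_hat[2:] - 2 * y_hat[1:-1]) ** 2)   # Eq. (17)
    tv = torch.mean(torch.abs(y_hat[:-1] - y_hat[1:]))                        # Eq. (18)
    return mse + alpha_s * (laplacian + tv)
```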

Computational setting

Unless otherwise stated, the following computational settings for machine learning are used. Given a target system, time series are generated numerically by integrating the system with time step dt = 0.01. The initial states of both the dynamical process and the neural network are randomly drawn from a uniform distribution. An initial transient phase of the time series is removed to ensure that the trajectory has reached the attractor. The training and testing data are obtained by sampling the time series at an interval Δs chosen to ensure acceptable generation quality. Specifically, for the chaotic food-chain, Lorenz, and Lotka–Volterra systems, we set Δs = 1, Δs = 0.02, and Δs = 1, respectively, so that each sampling interval corresponds to approximately 1/30 to 1/50 of an oscillation cycle. A similar procedure is applied to the other synthetic chaotic systems (see Table S3 for the Δs value of each system). The time series data are preprocessed with min-max normalization so that they lie in the range [0, 1]. The complete data length for each system is 1,500,000 points (about 30,000 cycles of oscillation), which is divided into segments with randomly chosen sequence lengths Ls and sparsity Sr. For the transformer, we use a maximum sequence length of 3000 (corresponding to about 60 cycles of oscillation), which is the upper limit on the input time-series length. We apply Bayesian optimization57 and a random search algorithm58 to systematically explore and identify the optimal set of hyperparameters. Two chaotic Sprott systems, Sprott0 and Sprott1, are used as validation systems to find the optimal hyperparameters and to train the final model weights, ensuring no data leakage from the testing systems. The optimized hyperparameters for the transformer are listed in Table 1. All simulations are run in Python on computers with six NVIDIA RTX A6000 GPUs. A single training run of our framework typically takes about 30 min on one GPU.
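
To make the segmenting procedure concrete, a simplified sketch is given below. The segment count, the length and sparsity ranges, and the treatment of Sr as an observation probability (with unobserved entries zero-filled, consistent with Eq. (5)) are assumptions for illustration only.

```python
import numpy as np

def make_training_segments(x, rng, n_segments=1000, min_len=100, max_len=3000):
    """Min-max normalize a long trajectory and cut random-length, randomly sparse segments."""
    x = (x - x.min(axis=0)) / (x.max(axis=0) - x.min(axis=0))   # normalize to [0, 1]
    segments = []
    for _ in range(n_segments):
        L_s = int(rng.integers(min_len, max_len + 1))            # random sequence length L_s
        start = int(rng.integers(0, len(x) - L_s))
        seg = x[start:start + L_s]
        alpha = rng.uniform(0.05, 0.5)                           # random observation probability
        mask = rng.random(seg.shape) < alpha                     # each entry observed with prob. alpha
        segments.append((np.where(mask, seg, 0.0), seg))         # (sparse input, ground truth) pair
    return segments

rng = np.random.default_rng(42)
# trajectory: integrated, transient-removed time series of shape (T, D)
# training_pairs = make_training_segments(trajectory, rng)
```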

Table 1 Optimal transformer hyperparameter values

Prediction stability

The prediction stability Rs quantifies the probability that the transformer generates a stable prediction, defined as the probability that the MSE is below a predefined stability threshold MSEc:

$${R}_{s}({{{{\rm{MSE}}}}}_{{{{\rm{c}}}}})=\frac{1}{n} {\sum}_{i=1}^{n}[{{{\rm{MSE}}}} < {{{{\rm{MSE}}}}}_{{{{\rm{c}}}}}],$$

(20)

where n is the number of iterations and [ ] denotes the indicator function, equal to one if the statement inside is true and zero otherwise.
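
A direct implementation of Eq. (20) is a simple count over the MSE values collected from repeated runs; the sketch below only illustrates this counting.

```python
import numpy as np

def prediction_stability(mse_values, mse_c):
    """Eq. (20): fraction of runs whose MSE falls below the stability threshold MSE_c."""
    return float(np.mean(np.asarray(mse_values) < mse_c))
```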

Deviation value

For a three-dimensional target system, we divide the three-dimensional phase space into a uniform cubic lattice with cell size Δ = 0.05 and count the number of trajectory points in each cell, for both the predicted and true attractors over a fixed time interval. The deviation value (DV) is defined as21

$${{{\rm{DV}}}}\equiv {\sum}_{i=1}^{{m}_{x}} {\sum}_{j=1}^{{m}_{y}} {\sum}_{k=1}^{{m}_{z}}\sqrt{{\left({f}_{i,j,k}-{\hat{f}}_{i,j,k}\right)}^{2}},$$

(21)

where mx, my, and mz are the total numbers of cells in the x, y, and z directions, respectively, and fi,j,k and \({\hat{f}}_{i,j,k}\) are the frequencies of visits to cell (i, j, k) by the predicted and true trajectories, respectively. If the predicted trajectory leaves the phase-space boundary, we count it as if it had landed in the boundary cells, which the true trajectory never visits.
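
Computationally, the DV of Eq. (21) amounts to binning both trajectories on the same cubic lattice and summing the absolute differences of the visit counts. In the sketch below, the phase-space bounds are supplied explicitly, out-of-range predicted points are clipped onto the boundary cells as described above, and raw counts are used as visit frequencies; these are implementation assumptions for illustration.

```python
import numpy as np

def deviation_value(pred_traj, true_traj, bounds, delta=0.05):
    """Eq. (21): sum of |f - f_hat| over a uniform cubic lattice with cell size delta.
    bounds: list of (min, max) pairs for the x, y, and z directions."""
    edges = [np.arange(lo, hi + delta, delta) for lo, hi in bounds]
    lows = np.array([lo for lo, _ in bounds])
    highs = np.array([hi for _, hi in bounds])
    pred_clipped = np.clip(pred_traj, lows, highs)     # escaped points land in boundary cells
    f_pred, _ = np.histogramdd(pred_clipped, bins=edges)
    f_true, _ = np.histogramdd(true_traj, bins=edges)
    return float(np.sum(np.abs(f_pred - f_true)))
```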

Noise implementation

We study how two types of noise, multiplicative and additive, affect the dynamics reconstruction in this work. Both are normally distributed stochastic processes with zero mean and standard deviation σ; the former perturbs the observational points x to x + xξ after normalization, and the latter perturbs x to x + ξ. Note that multiplicative (demographic) noise is common in ecological systems.
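
The two noise models can be sketched as follows, applied to the normalized observations with ξ drawn independently for each data point; the function and argument names are illustrative.

```python
import numpy as np

def perturb(x, sigma, kind="additive", rng=None):
    """Add measurement noise of standard deviation sigma to normalized observations x."""
    if rng is None:
        rng = np.random.default_rng()
    xi = rng.normal(0.0, sigma, size=x.shape)
    if kind == "multiplicative":      # demographic noise: x -> x + x * xi
        return x + x * xi
    return x + xi                      # additive noise: x -> x + xi
```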

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.


