Distilling knowledge from graph neural networks trained on cell graphs to non-neural student models

Machine Learning


Datasets based on cell graphs and non-cell graphs

For this work, we used three cell graph-based datasets: one from our previous paper on tuberculosis (TB)50, a placenta histology dataset51, and the TCGA Breast Cancer Cell Classification Dataset (BRCA-M2C)52. The TB dataset contained 44 whole slide images (WSIs) with an average size of 42,831 × 41,159 pixels at 40X magnification. Nodes were classified as either acid-fast bacilli (AFB) or nuclei of activated macrophages; the approach used to determine cell locations and classify cell types is detailed in our previous work50. We used 34 WSIs for training and validation and reserved 10 WSIs for the test set. The training and test WSIs used in this study differed from those used in50. The training set had 90,878 nodes, the validation set 22,708 nodes, and the test set 76,316 nodes.

The placenta dataset consisted of two cell graphs constructed from two placenta histology WSIs, which were combined into a single graph with nine classes. We used the original 64-dimensional feature set provided with the dataset. These features primarily focused on the morphological characteristics of the cells. Our goal was to evaluate the efficacy of knowledge distillation on a cell graph dataset whose node features did not include cell graph features. The feature extraction process is described in51. We followed the dataset's original train, validation, and test split, considering only labeled nodes.

The BRCA-M2C dataset52 provides dot annotations for multi-class cell classification in breast cancer images, consisting of the coordinates of the annotated cells and their corresponding labels. The cell extraction and labeling process is described in52. The images were patches of 1000×1000 pixels extracted at the highest resolution and downsampled to 20X, so each image was approximately 500×500 pixels. The cell classes were lymphocytes, breast cancer (tumor) cells, and stromal cells. The original split contained 80 annotated images in the training set, 10 in the validation set, and 30 in the test set. We pooled the training and validation images and re-split them into new training and validation sets while keeping the test set unchanged. This resulted in 19,602 training nodes, 2,178 validation nodes, and 8,858 test nodes.

To assess whether our approach generalizes to non-cell graph datasets, where no features extracted from cell graphs are available, we used three non-cell graph datasets: CoauthorCS, CoauthorPhysics, and a synthetic dataset. Each of these datasets consisted of a single graph. The CoauthorCS dataset consisted of 18,333 nodes and 163,788 edges, with nodes divided into 15 classes, and each node was represented by a 6,805-dimensional feature vector. Its training set had 12,833 nodes, the validation set 3,666 nodes, and the test set 1,834 nodes. The CoauthorPhysics dataset contained 34,493 nodes and 495,924 edges, with nodes categorized into five classes and represented by 8,415-dimensional feature vectors. Its training set had 24,145 nodes, the validation set 6,898 nodes, and the test set 3,450 nodes. These two datasets were used only to evaluate the applicability of our approach to non-cell graph settings and were not included in the ablation studies. We generated a synthetic dataset of 60,000 nodes using the preferential attachment mechanism of the Barabási-Albert model53. Seven topological features were extracted from this graph to represent its structural properties. The training, validation, and test sets contained 42,000, 12,000, and 6,000 nodes, respectively.
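The synthetic graph can be reproduced in spirit with NetworkX's `barabasi_albert_graph`. The text does not state the attachment parameter m, the random seed, how labels were assigned, or which seven topological features were used, so m=3, the seed, and the feature list below are purely illustrative assumptions. The 70/20/10 random split is likewise an assumption chosen only because it reproduces the stated 42,000/12,000/6,000 node counts.

```python
import networkx as nx
import numpy as np


def make_ba_graph_features(n_nodes=60_000, m=3, seed=0, betweenness_samples=100):
    """Barabasi-Albert graph plus a 7-column topological feature matrix.

    Assumption: m, the seed and the seven features below are illustrative;
    the paper does not list the features it extracted.
    """
    G = nx.barabasi_albert_graph(n_nodes, m, seed=seed)
    degree_cent = nx.degree_centrality(G)
    clustering = nx.clustering(G)
    triangles = nx.triangles(G)
    core = nx.core_number(G)
    avg_nbr_deg = nx.average_neighbor_degree(G)
    # Betweenness is estimated from a sample of source nodes to keep it tractable
    betweenness = nx.betweenness_centrality(
        G, k=min(betweenness_samples, n_nodes), seed=seed
    )
    X = np.array(
        [
            [G.degree(v), degree_cent[v], clustering[v], triangles[v],
             core[v], avg_nbr_deg[v], betweenness[v]]
            for v in range(n_nodes)  # BA nodes are labelled 0..n_nodes-1
        ],
        dtype=float,
    )
    return G, X


def split_nodes(n_nodes, fractions=(0.7, 0.2, 0.1), seed=0):
    """Random train/validation/test split of node indices (assumed 70/20/10)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_nodes)
    n_train = int(round(fractions[0] * n_nodes))
    n_val = int(round(fractions[1] * n_nodes))
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]
```

Sampled betweenness is used because exact betweenness on 60,000 nodes is expensive; any of the seven columns could be swapped for the features actually used.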

Datasets with a minority class proportion between 20% and 40% are generally considered mildly imbalanced, those with proportions from 1% to 20% moderately imbalanced, and those with a minority class proportion below 1% extremely imbalanced54. Based on this classification, the TB and BRCA-M2C datasets were mildly imbalanced, the CoauthorPhysics dataset was moderately imbalanced, and the Placenta, CoauthorCS, and synthetic datasets were extremely imbalanced.
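The classification rule above reduces to a check on the smallest class proportion. A minimal sketch follows; how the boundary values (exactly 1%, 20%, 40%) are assigned is an assumption, since the cited scheme only gives ranges.

```python
import numpy as np


def imbalance_category(labels):
    """Return (minority proportion, category) using the thresholds cited in the text.

    Assumption: boundaries are treated as [20%, 40%] mild, [1%, 20%) moderate,
    < 1% extreme; anything above 40% falls outside the scheme.
    """
    _, counts = np.unique(np.asarray(labels), return_counts=True)
    minority = counts.min() / counts.sum()
    if minority < 0.01:
        category = "extreme"
    elif minority < 0.20:
        category = "moderate"
    elif minority <= 0.40:
        category = "mild"
    else:
        category = "outside scheme (>40%)"
    return float(minority), category
```

For a three-class label vector with 50, 45, and 5 samples, the minority proportion is 0.05, so the dataset counts as moderately imbalanced.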

Construction of cell graph

Edge construction in cell graphs estimates the biological likelihood that neighboring cells interact within the same structure. The edge threshold for intercellular communication is critical in cellular studies, and many investigations have aimed to determine the optimal distance for accurately modeling these interactions. Pathologists' input provides valuable guidance for refining graph representations so that they accurately reflect the biological relationships between cells55. Many prior works have used a single threshold value to map cell-cell interactions23,33, while others have experimented with varying edge thresholds, such as 60, 75, and 90 \(\mu\)m, to identify an appropriate value56. In contrast, our approach uses a distinct threshold value for each pair of cell types.

In the TB dataset, nodes represent either AFBs or the nuclei of activated macrophages. Edge thresholds were based on the length of the cords of M.tb-infected cells after 72 hours of infection9 and on the fact that macrophages can extend their pseudopods beyond their normal boundary (radius) to detect cells farther away. We hypothesize that AFBs can interact with other AFBs within a distance of 150 \(\mu\)m, equivalent to 615 pixels at the magnification used in this study57. Likewise, activated macrophage nuclei may interact with AFBs and with each other if they are within 500 \(\mu\)m (2049 pixels)10. Our domain expert reviewed and validated these threshold values.

The adjacency matrix is computed as follows:

$$\begin{aligned} A_{uv}=\left\{ \begin{array}{lc} 1 & \text{ if } d(u, v) \le T(c_u, c_v) \\ 0 & \text{ otherwise } \end{array}\right. \end{aligned}$$

Here, \(c_u\) and \(c_v\) are the cell types of nodes u and v, \(T(c_u, c_v)\) is the distance threshold for that pair of cell types, and d(u, v) is the Euclidean distance computed using Eq. 1, where \((x_u, y_u)\) and \((x_v, y_v)\) are the image coordinates of nodes u and v, respectively.

$$\begin{aligned} d(u, v)=\sqrt{\left( x_u-x_v\right) ^2+\left( y_u-y_v\right) ^2} \end{aligned}$$

(1)

The distance threshold values are listed in Table 1.

Table 1 Distance thresholds.
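Pair-specific distance thresholding can be sketched with a k-d tree: find every pair within the largest threshold once, then keep a pair only if its distance is within the threshold for its two cell types. The pixel values in the usage block are the ones quoted in the text (615 px for AFB-AFB, 2049 px for pairs involving macrophage nuclei); Table 1 remains the authoritative source, and the type names "AFB"/"MAC" and the function name are hypothetical.

```python
import numpy as np
from scipy.spatial import cKDTree


def threshold_cell_graph(coords, cell_types, thresholds):
    """Edges between cells whose Euclidean distance (Eq. 1) is within a pair-specific threshold.

    coords: (n, 2) pixel coordinates; cell_types: length-n sequence of type names.
    thresholds: dict frozenset({type_a, type_b}) -> max distance in pixels
    (frozenset({"AFB"}) covers AFB-AFB). Pairs absent from the dict get no edge.
    """
    coords = np.asarray(coords, dtype=float)
    tree = cKDTree(coords)
    max_r = max(thresholds.values())
    edges = []
    # query_pairs yields each pair (i, j), i < j, with distance <= max_r exactly once
    for i, j in tree.query_pairs(r=max_r):
        limit = thresholds.get(frozenset((cell_types[i], cell_types[j])))
        if limit is not None and np.linalg.norm(coords[i] - coords[j]) <= limit:
            edges.append((int(i), int(j)))
    return sorted(edges)


if __name__ == "__main__":
    # Hypothetical type names; pixel thresholds as quoted in the text
    tb_thresholds = {
        frozenset({"AFB"}): 615,
        frozenset({"AFB", "MAC"}): 2049,
        frozenset({"MAC"}): 2049,
    }
    coords = [(0, 0), (600, 0), (0, 1000), (5000, 5000)]
    types = ["AFB", "AFB", "MAC", "MAC"]
    print(threshold_cell_graph(coords, types, tb_thresholds))
```

Querying once at the maximum radius avoids building a separate tree per cell-type pair; the per-pair filter then applies the tighter thresholds.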

For the placenta dataset, the dataset authors51 generated the cell graphs from the intersection of the k-nearest neighbors (KNN)58 graph with k = 5 and the Delaunay triangulation59 graph. In this graph, nodes represent cells and edges represent their interactions.
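A KNN ∩ Delaunay graph of this kind can be sketched with SciPy. This is not the dataset authors' code; in particular, treating a KNN edge as undirected (kept if either endpoint lists the other among its k nearest neighbors) is an assumption.

```python
import numpy as np
from scipy.spatial import Delaunay, cKDTree


def delaunay_edges(coords):
    """Undirected edges (i < j) of the Delaunay triangulation of 2-D points."""
    edges = set()
    for simplex in Delaunay(coords).simplices:
        for a in range(3):
            for b in range(a + 1, 3):
                i, j = int(simplex[a]), int(simplex[b])
                edges.add((min(i, j), max(i, j)))
    return edges


def knn_edges(coords, k):
    """Undirected KNN edges: (i, j) is kept if j is among i's k nearest or vice versa."""
    tree = cKDTree(coords)
    # Query k + 1 neighbours because the closest point to each node is itself
    _, idx = tree.query(coords, k=k + 1)
    edges = set()
    for i, row in enumerate(idx):
        for j in row[1:]:
            edges.add((min(i, int(j)), max(i, int(j))))
    return edges


def knn_delaunay_intersection(coords, k=5):
    """Edges present in both the KNN graph and the Delaunay triangulation."""
    coords = np.asarray(coords, dtype=float)
    return sorted(delaunay_edges(coords) & knn_edges(coords, k))
```

The intersection keeps the planar, non-crossing structure of the triangulation while discarding long Delaunay edges that span sparse regions.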

For the BRCA-M2C dataset, we constructed cell graphs in which nodes represent cells and edges represent interactions, using the k-nearest neighbors (KNN)58 approach. Different k-values were used for each pair of cell types to reflect the biological significance of their interactions; the values are listed in Table 2, and the adjacency matrix is calculated using Eq. 2. The k-values were chosen to reflect the cohesiveness of tumor cells and the solitary nature of stromal cells in tumors. Lymphocyte interactions, whether with tumor cells or among themselves, were assigned moderate k-values to reflect their intermediate proximity during immune surveillance. Figure 1 illustrates the cell graphs for the various datasets.

$$\begin{aligned} A[i, j]=\left\{ \begin{array}{lc} 1 & \text{ if } j \in K N N(i) \\ 0 & \text{ otherwise } \end{array}\right. \end{aligned}$$

(2)

Table 2 k-values for different types of cell interactions.
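One way to apply Eq. 2 with pair-specific k-values is to connect each cell of type a to its k(a, b) nearest cells of type b. The text does not specify this exact procedure, so the per-type-pair search, the symmetrization of edges, and the function name are assumptions; the k-values in the usage below are made up, and Table 2 would supply the real ones.

```python
import numpy as np
from scipy.spatial import cKDTree


def typed_knn_graph(coords, cell_types, k_table):
    """Connect each cell of type a to its k_table[(a, b)] nearest cells of type b.

    Edges are stored undirected as (i, j) with i < j (assumption).
    """
    coords = np.asarray(coords, dtype=float)
    cell_types = np.asarray(cell_types)
    edges = set()
    for (src, dst), k in k_table.items():
        src_idx = np.flatnonzero(cell_types == src)
        dst_idx = np.flatnonzero(cell_types == dst)
        if k <= 0 or len(src_idx) == 0 or len(dst_idx) == 0:
            continue
        tree = cKDTree(coords[dst_idx])
        # Ask for one extra neighbour when src == dst so the node itself can be skipped
        n_query = min(k + int(src == dst), len(dst_idx))
        _, nbrs = tree.query(coords[src_idx], k=n_query)
        nbrs = np.asarray(nbrs).reshape(len(src_idx), -1)
        for i, row in zip(src_idx, nbrs):
            taken = 0
            for j in dst_idx[row]:
                if j == i:
                    continue
                if taken == k:
                    break
                edges.add((int(min(i, j)), int(max(i, j))))
                taken += 1
    return sorted(edges)


if __name__ == "__main__":
    # Hypothetical k-values, for illustration only
    k_table = {("T", "T"): 1, ("T", "S"): 1}
    print(typed_knn_graph([(0, 0), (1, 0), (10, 0), (11, 0)], ["T", "T", "S", "S"], k_table))
```

Because k_table is keyed by ordered pairs, asymmetric choices (for example, tumor cells linking to many tumor neighbors but few stromal ones) are possible; symmetrizing afterwards keeps the graph undirected.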
Figure 1

Cell graphs of the TB and BRCA-M2C datasets were generated using the NetworkX library60 (version 3.4.2, https://networkx.org/). (A) Cell graph generated for a TB image. Acid-fast bacilli (AFB) are shown in red, and the nuclei of activated macrophages in blue. Black edges represent interactions. (B) Cell graph generated for normal (uninfected) lung tissue. (C) Cell graph from Vanea et al.51, licensed under a Creative Commons Attribution 4.0 International License (https://creativecommons.org/licenses/by/4.0/). (D) Cell graph generated from the BRCA-M2C dataset using different k-values for specific cell-type interactions; red nodes represent lymphocytes, blue nodes tumor cells, green nodes stromal cells, and gray edges their interactions.

Are all these edges required?

While the cell graphs used in our study are generated by considering biological interactions, we acknowledge that they might not represent the optimal cell graphs. The edges in these graphs capture critical intercellular interactions. However, determining the optimal edges for such graphs remains an open research question. These interactions prove to be highly beneficial, particularly when the test set originates from a distribution different from the training set. Randomly removing edges from the cell graphs has been shown to hamper the teacher model’s performance. This, in turn, degrades the performance of the student models, as the quality of the teacher’s logits diminishes. The concept of optimal cell graphs with the right amount of connectivity to balance model complexity and performance remains an emerging area of research that requires further exploration.

Feature extraction

We tested the efficacy of our approach with different feature sets across datasets. For the TB dataset, we combined local cell graph features with morphological features. For the Placenta dataset, we used only morphological features (along with inherent variations in cell appearance). For the BRCA-M2C dataset, we used only local cell graph features. For the Coauthor datasets, we did not extract additional features and instead used the original features provided with the datasets.

TB dataset

In our previous work50, combining morphological and graph features yielded the best performance for CG-JKNN. Hence, we use this combination to train our models in this work. Table 3 lists the extracted features; their descriptions can be found in50.

Placenta dataset

For the placenta dataset, we used the features defined in the original paper. Specifically, the node features are defined using the nucleus coordinates as node coordinates and the 64-dimensional embeddings from the penultimate layer of the cell classifier model. These features primarily encode morphological information about cells rather than cell graph structural information.

BRCA-M2C dataset

For the BRCA-M2C dataset, we extracted the local graph features from the cell graphs generated. The extracted features are listed in Table 4.
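Table 4 is not reproduced here, so the following is only a hedged illustration of what "local graph features" can look like: degree, clustering coefficient, average neighbor degree, and the fraction of neighbors of each cell type, computed with NetworkX. The feature choice, the `cell_type` node attribute, and the type names are assumptions, not the paper's feature list.

```python
import networkx as nx
import numpy as np


def local_graph_features(G, type_attr="cell_type",
                         types=("lymphocyte", "tumor", "stromal")):
    """Per-node local features: degree, clustering, avg neighbour degree, neighbour-type fractions.

    Rows follow the iteration order of G.nodes(). The feature set is illustrative only.
    """
    clustering = nx.clustering(G)
    avg_nbr_deg = nx.average_neighbor_degree(G)
    rows = []
    for v in G.nodes():
        nbr_types = [G.nodes[u][type_attr] for u in G.neighbors(v)]
        deg = len(nbr_types)
        # Composition of the neighbourhood by cell type (0 for isolated nodes)
        fractions = [nbr_types.count(t) / deg if deg else 0.0 for t in types]
        rows.append([deg, clustering[v], avg_nbr_deg[v], *fractions])
    return np.array(rows, dtype=float)
```

Neighborhood composition features of this kind let a non-graph student see a summary of each cell's surroundings without access to the graph itself.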

Distilling the knowledge from CG-JKNN (teacher) to tree-based ensembles (students)

The teacher model, based on the CG-JKNN architecture50, is designed for node-level classification. A graph is defined as \(G = (V, E)\), where V denotes the set of n nodes and each node v is associated with a d-dimensional feature vector \(x_v \in \mathbb {R}^d\). Each edge in E is represented as \(e_{u, v} = (u, v)\), indicating a connection between nodes u and v. The adjacency matrix \(A \in \mathbb {R}^{n \times n}\) encodes the graph structure.

Figure 2

Architecture of the teacher model used for knowledge distillation. To obtain the temperature-scaled logits, as discussed in the ablation study, a temperature-scaling block needs to be incorporated between the logits generated by the teacher model and the input to the student models.

The architecture of our teacher model and the flow of our proposed work are depicted in Fig. 2. To train the teacher GNN, we use the constructed cell graphs G together with their node features \(x_v\). During training, the model learns to predict the label of each node from the labeled graphs. During testing, the trained GNN receives unseen cell graphs G and their node features \(x_v\), and its predicted node labels are compared against the true test labels to evaluate performance.

Each node's hidden representation \(h_v^{(l)} \in \mathbb {R}^d\) in the l-th layer is initialized with the input features as \(h_v^{(0)} = x_v\). The GraphSAGE layers update node representations using a mean aggregation function (Eq. 3) to gather information from neighboring nodes. In our previous work50, we experimented with both mean and max aggregators and found that the mean aggregator consistently achieved superior performance. This is also consistent with prior studies demonstrating the effectiveness of mean aggregation in node classification tasks61,62. Therefore, we selected the mean aggregator.

$$\begin{aligned} \varvec{h_{N(v)}^{(l)}} = \operatorname {MEAN} \left( \left\{ \varvec{h_u^{(l-1)}}, \forall u \in N(v)\right\} \right) \end{aligned}$$

(3)

Here, \(h_{N(v)}^{(l)}\) represents the aggregated neighborhood representation, and \(h_u^{(l-1)}\) corresponds to the representation of neighboring node u from the previous layer. The node’s updated representation is computed using Eq. 4.

$$\begin{aligned} \varvec{h_v^{(l)}} = \sigma \left( W \cdot \left[ \varvec{h_v^{(l-1)}}, \varvec{h_{N(v)}^{(l)}} \right] \right) \end{aligned}$$

(4)

Here, W is the learnable weight matrix, and \(\sigma\) denotes the activation function (ReLU).

The "jumping knowledge representation learning" mechanism12 is incorporated to combine multi-layer node representations. Instead of using only the final layer's representation, this approach concatenates the representations from all layers into a single node representation (Eq. 5). The authors of12 explored three aggregation mechanisms: concatenation, max-pooling, and LSTM-based attention. Our network adopts the concatenation-based jumping knowledge mechanism.

$$\begin{aligned} \varvec{h_v^{(Concatenated)}} = \text {Concatenate} \left[ \varvec{h_v^{(1)}}, \ldots , \varvec{h_v^{(l)}} \right] \end{aligned}$$

(5)
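Equations 3-5 can be traced with a dense NumPy sketch: each layer averages neighbor representations, concatenates them with the node's own representation, applies a weight matrix and ReLU, and the outputs of all layers are concatenated. This is an illustration, not the CG-JKNN implementation. It uses row vectors, so the weight appears on the right ([h_v || h_N(v)] @ W) rather than on the left as in Eq. 4, and giving isolated nodes a zero neighbor vector is an assumption.

```python
import numpy as np


def sage_mean_layer(H, neighbors, W):
    """One GraphSAGE layer with mean aggregation (Eqs. 3-4).

    H: (n, d_in) representations from the previous layer.
    neighbors: list of neighbour-index lists, one per node.
    W: (2 * d_in, d_out) weights applied to [h_v || h_N(v)].
    """
    n, d_in = H.shape
    H_nbr = np.zeros((n, d_in))
    for v, nbrs in enumerate(neighbors):
        if len(nbrs) > 0:
            H_nbr[v] = H[list(nbrs)].mean(axis=0)  # Eq. 3: mean over N(v)
    Z = np.concatenate([H, H_nbr], axis=1) @ W     # Eq. 4: W applied to [h_v, h_N(v)]
    return np.maximum(Z, 0.0)                      # ReLU


def jk_concat_forward(X, neighbors, weights):
    """Stack SAGE layers and concatenate every layer's output (Eq. 5)."""
    H, layer_outputs = X, []
    for W in weights:
        H = sage_mean_layer(H, neighbors, W)
        layer_outputs.append(H)
    return np.concatenate(layer_outputs, axis=1)
```

The concatenated output has as many columns as the sum of all layer widths, which is what the subsequent attention layer receives.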

After concatenation, the node representations are passed through a GATv2 layer63, which refines the representations using an attention mechanism. The attention coefficients \(\alpha _{vu}\) are computed as:

$$\begin{aligned} \alpha _{vu} = \operatorname {softmax}_u \left( \varvec{a}^T \operatorname {LeakyReLU} \left( W \left[ \varvec{h_v^{(\text {Concatenated})}} \Vert \varvec{h_u^{(\text {Concatenated})}} \right] \right) \right) \end{aligned}$$

(6)
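The attention scoring can be sketched in NumPy for a single head. The sketch uses the GATv2 form of Brody et al., in which LeakyReLU is applied before the dot product with a: e_vu = aᵀ LeakyReLU(W [h_v ∥ h_u]), followed by a softmax over u ∈ N(v). Row-vector conventions, a single attention head, and the LeakyReLU slope of 0.2 (a common default) are assumptions.

```python
import numpy as np


def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)


def gatv2_attention(H, neighbors, W, a, slope=0.2):
    """GATv2-style attention coefficients alpha[v] over N(v), one array per node.

    H: (n, d) concatenated JK representations; W: (2d, d_out); a: (d_out,).
    Score: e_vu = a^T LeakyReLU(W^T [h_v || h_u]), then softmax over u in N(v).
    """
    alphas = []
    for v, nbrs in enumerate(neighbors):
        nbrs = list(nbrs)
        if not nbrs:
            alphas.append(np.array([]))
            continue
        # One row [h_v || h_u] per neighbour u
        pairs = np.concatenate(
            [np.repeat(H[v][None, :], len(nbrs), axis=0), H[nbrs]], axis=1
        )
        scores = leaky_relu(pairs @ W, slope) @ a
        scores = scores - scores.max()  # numerical stability before exp
        weights = np.exp(scores)
        alphas.append(weights / weights.sum())
    return alphas
```

Applying the non-linearity before the dot product with a is what distinguishes GATv2 from the original GAT, where a is applied first and the ranking of neighbors is the same for every query node.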

Finally, the node representations are updated as shown in Eq. 7, and the softmax function is then applied to obtain the class probabilities.

$$\begin{aligned} \varvec{h_v^{(GAT)}} = \sigma \left( \sum _{u \in \mathscr {N}(v)} \alpha _{vu} W \varvec{h_u^{(\text {Concatenated})}} \right) \end{aligned}$$

(7)

Here, \(\mathscr {N}(v)\) denotes the neighbors of node v, and \(\sigma\) is the activation function, for which we use a rectified linear unit (ReLU).

Over-smoothing is a critical issue in GNNs: as networks become deeper, node representations tend to converge and lose their distinctiveness. Existing approaches address this challenge in various ways. Energetic Graph Neural Networks employ energy-based modeling64, Graph DropConnect introduces graph-specific dropout65, Graph-Coupled Oscillator Networks use non-linear oscillators to modify GNN dynamics66, and residual connections improve information flow in deep GNNs to counter over-smoothing67. In this study, we adopted the DropEdge technique68, which mitigates over-smoothing by randomly removing a proportion of edges during training. Using the edge index representation of graph connections, we experimented with various dropping rates.
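DropEdge on an edge index can be sketched in a few lines of NumPy. The sketch assumes a (2, E) edge index in which each undirected edge appears in both directions and drops both directions together; whether the original implementation did so is not stated. PyTorch Geometric offers `torch_geometric.utils.dropout_edge` for the same purpose in a training loop.

```python
import numpy as np


def drop_edge(edge_index, p, rng, undirected=True):
    """Randomly remove a fraction p of edges (DropEdge), resampled each training epoch.

    edge_index: (2, E) integer array. With undirected=True, (u, v) and (v, u)
    are kept or dropped together (assumption).
    """
    edge_index = np.asarray(edge_index)
    if p <= 0 or edge_index.shape[1] == 0:
        return edge_index
    if not undirected:
        return edge_index[:, rng.random(edge_index.shape[1]) >= p]
    src, dst = edge_index
    n = int(edge_index.max()) + 1
    # Encode each undirected edge {u, v} as one integer key so both directions share it
    keys = np.minimum(src, dst) * n + np.maximum(src, dst)
    uniq, inverse = np.unique(keys, return_inverse=True)
    keep_uniq = rng.random(len(uniq)) >= p
    return edge_index[:, keep_uniq[inverse]]
```

A fresh mask is drawn on every call, so calling it once per epoch exposes the model to a different sparsified graph each time, while evaluation uses the full graph.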

Logits are the unnormalized outputs of a model and carry richer information than class probabilities. Prior work has shown that training the student model directly on logits allows it to learn the internal representations captured by the teacher more effectively18, enabling the student to better mimic the teacher's learned patterns. It also avoids the information loss that typically occurs when logits are transformed into probabilities. Hence, we extract the logits before the softmax function is applied and use them as targets to train the student regressor models.

In general, the KD loss69 aligns the predictions of the student model with those of the teacher model by minimizing the divergence between their output distributions, typically measured with the Kullback-Leibler (KL) divergence. While this approach is effective for neural network student models trained by gradient descent, it is not directly applicable to our scenario. In our study, the student models are tree-based ensembles that are fit as regressors to the teacher's logits rather than trained by backpropagating a divergence loss. As a result, we do not use this loss function.
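For contrast, the standard temperature-scaled KD term (Hinton-style) that the tree-based students do not use can be written as T² · KL(p_teacher^T ∥ p_student^T), where p^T is the softmax of logits divided by the temperature T. The temperature here plays the same role as the temperature-scaling block mentioned for the ablation study; T=2 and the small epsilon are arbitrary choices for the sketch.

```python
import numpy as np


def softmax(z, T=1.0):
    """Row-wise softmax of logits divided by temperature T."""
    z = np.asarray(z, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def kd_kl_loss(student_logits, teacher_logits, T=2.0, eps=1e-12):
    """Hinton-style distillation term T^2 * KL(p_teacher || p_student), averaged over samples."""
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    kl = np.sum(p_t * (np.log(p_t + eps) - np.log(p_s + eps)), axis=-1)
    return float(T ** 2 * kl.mean())
```

Minimizing this quantity requires gradients with respect to the student's outputs, which is why a regression fit to the logits is used for the tree ensembles instead.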

After training on the teacher's logits as targets, the student models generate predictions, which are converted into probabilities using the softmax function. These probabilities are used to calculate performance metrics such as accuracy and F1-score. We specifically chose non-linear models as students because the teacher logits, which serve as their targets, are produced by a highly non-linear GNN. To learn effectively from these logits, the student models must have sufficient capacity (or complexity) to capture the non-linear relationships embedded in the teacher's predictions. We employ the tree-based ensemble regressors listed in Table 5 as student models. For brevity, we often refer to these models by their specific names rather than repeatedly using the term 'regressor'.

Table 5 Student models and their descriptions.
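The train-on-logits, evaluate-on-labels loop can be sketched with scikit-learn. A `RandomForestRegressor` stands in for the students in Table 5 because it supports multi-output regression natively; the hyperparameters and the function name are assumptions. Note that the softmax does not change the argmax, so predicted labels come straight from the largest predicted logit.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import accuracy_score, f1_score


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def distill_to_forest(X_train, teacher_logits_train, X_test, y_test, seed=0):
    """Fit a multi-output forest to teacher logits, then score it against true labels."""
    student = RandomForestRegressor(n_estimators=200, random_state=seed)
    student.fit(X_train, teacher_logits_train)   # regression on (n, C) logits
    probs = softmax(student.predict(X_test))     # predicted logits -> probabilities
    y_pred = probs.argmax(axis=1)
    metrics = {
        "accuracy": accuracy_score(y_test, y_pred),
        "f1_weighted": f1_score(y_test, y_pred, average="weighted"),
    }
    return student, metrics
```

Gradient-boosting students such as LightGBM regress one output per model, so a per-class regressor (or scikit-learn's `MultiOutputRegressor` wrapper) would replace the single multi-output forest.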

Estimating the complexity of tree-based ensemble models-an approximation and distillation quality score

Understanding the complexity of student models is essential for evaluating the quality of the knowledge distilled from the teacher model. Black-box models, including various ensemble techniques, depart from traditional likelihood-based frameworks, which makes their complexity difficult to assess directly, mainly because the number of parameters in such models does not accurately represent their degrees of freedom. The concept of Generalized Degrees of Freedom (GDF), introduced by Ye70 and later applied to machine learning by Elder71, provides a metric for assessing model complexity. For instance, in a two-dimensional decision tree experiment, Elder71 observed that combining multiple trees through bagging yields an ensemble whose GDF is lower than that of any single tree within it. The authors of72 used GDF to estimate the number of parameters of a random forest model that predicted cell-type-specific enhancer-promoter interactions from information on protein-protein interactions between transcription factors.

Despite the utility of GDF in providing an estimate of model complexity, it has some challenges. Firstly, the sensitivity of GDF to perturbations in the data means that the degree to which GDF reacts can vary significantly depending on the specific modeling approach being used. This variability indicates that a GDF estimation method that works for one model type may not be suitable for another. In addition, the absence of a robust, universally applicable method for estimating GDF complicates its implementation across different data distributions and model architectures. These drawbacks highlight the complexity of accurately assessing model behavior in machine learning and the need for further research in developing more adaptable metrics like GDF73.
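For reference, Ye's GDF can be estimated by Monte Carlo perturbation: add small noise δ to the targets, refit, and sum over observations the slope of each fitted value with respect to its own perturbation. The sketch below is not the method used in this paper (which relies on cross-validation); the noise scale σ and the number of replicates are arbitrary assumptions, and the estimate is noisy, which illustrates the sensitivity issue discussed above. For ordinary least squares with an intercept the result should be close to the number of coefficients.

```python
import numpy as np
from sklearn.linear_model import LinearRegression


def generalized_degrees_of_freedom(model_factory, X, y, n_reps=100, sigma=0.1, seed=0):
    """Monte Carlo GDF estimate: sum_i d(fitted_i) / d(y_i) via target perturbation.

    model_factory: callable returning a fresh, unfitted regressor.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y, dtype=float)
    n = len(y)
    deltas = rng.normal(0.0, sigma, size=(n_reps, n))
    fitted = np.empty((n_reps, n))
    for r in range(n_reps):
        model = model_factory()
        model.fit(X, y + deltas[r])
        fitted[r] = model.predict(X)
    # Per-observation least-squares slope of fitted value on its own perturbation
    d_c = deltas - deltas.mean(axis=0)
    f_c = fitted - fitted.mean(axis=0)
    slopes = (d_c * f_c).sum(axis=0) / (d_c ** 2).sum(axis=0)
    return float(slopes.sum())


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    X = rng.normal(size=(200, 3))
    y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=200)
    print(generalized_degrees_of_freedom(LinearRegression, X, y, n_reps=200))
```

Each GDF estimate requires n_reps full refits, which is the processing cost that makes the approach expensive for large ensembles.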

A standard model selection criterion is the Akaike Information Criterion (AIC)74, which captures the trade-off between model complexity and goodness of fit; lower AIC values indicate a better balance between the two. It is computed using Eq. 8, where \(M_{k}\) denotes a model with dimension k and \(L(M_{k})\) is the likelihood of \(M_{k}\).

$$\begin{aligned} \operatorname {AIC}\left( M_k\right) =-2 \log L\left( M_k\right) +2 k \end{aligned}$$

(8)

However, one limitation of AIC is that it is unsuitable for non-parametric model selection75. Models such as random forests are non-parametric76. Contrary to a common misconception, non-parametric models do have parameters; they can be thought of as having an infinite number of parameters, which means that their complexity can grow to capture increasingly precise information as the amount of data increases76. A few papers, such as77, have computed the AIC for models like random forests. That study developed a machine-learning model to simulate the effect of masks on motor sound, using noise level data in decibels from various motor operation frequencies at the National Synchrotron Radiation Research Center (NSRRC). Four indicators were used to assess learning performance: the Akaike Information Criterion (AIC), the Hannan-Quinn Information Criterion (HQIC), the Schwarz-Bayesian Information Criterion (SBIC), and the Akaike Information Criterion with Small Sample Correction (AICc). However, based on the information provided, the method used to determine the number of parameters (k) for the AIC score is unclear.

When models are estimated by maximum likelihood, selecting the model that minimizes the cross-validation error is asymptotically equivalent to selecting the model that minimizes the AIC78. Building on this, the authors of73 argued that a measure of model complexity can be extracted from \(\ell_{CV}\), the sum over K folds of the log-likelihood of the validation fold. Equation (9) expresses the asymptotic equivalence between AIC and leave-one-out cross-validation (LOOCV), from which the number of parameters p can be estimated using Eq. (11). Here, \(\ell_m\) denotes the maximum log-likelihood of the original (non-cross-validated) model, and \(\ell_{CV}\) represents the sum over K folds of the log-likelihood of the validation fold.

$$\begin{aligned} \begin{array}{r} \textrm{AIC}=-2 \ell _m+2 \hat{p} \approx -2 \ell _{\textrm{CV}} \\ \end{array} \end{aligned}$$

(9)

$$\begin{aligned} -2 \ell _m+2 \hat{p}&\approx -2 \ell _{\textrm{CV}} \\ 2 \hat{p}&\approx -2 \ell _{\textrm{CV}}+2 \ell _m \\ \hat{p}&\approx 2\left( \ell _m-\ell _{\textrm{CV}}\right) /2 \end{aligned}$$

(10)

$$\begin{aligned} \begin{array}{r} \hat{p} \approx \ell _m-\ell _{\textrm{CV}} \\ \end{array} \end{aligned}$$

(11)

In our work, we employ tree-based ensemble regressors as student models, which are non-likelihood models. The authors of73 found that applying GDF to non-likelihood models, as a way of extending information-theoretic measures of model fit such as AIC, was computationally expensive and produced inconsistent results. Cross-validation was a more direct method but was less stable than GDF. As a model complexity metric, they suggested repeated 10-fold cross-validation. Cross-validation is suitable for models that do not make likelihood assumptions because it can, but need not, use the likelihood fit.

We build our methodology based on this idea. We utilize the sum of squared errors (SSE) to approximate the log-likelihood term. It suits our models that do not directly maximize the likelihood function. A higher maximum log-likelihood value indicates that the observed data is more probable under the model, which is interpreted as a better fit. A lower SSE suggests that the model’s predictions are closer to the actual observed values, which is also interpreted as a better fit.

Equation (12) shows the computation of model complexity with SSE, where \(SSE_{full}\) denotes the sum of squared errors on the training set, \(SSE_{CV}\) denotes the SSE of the cross-validated predictions, and n is the number of observations. Equation (12) follows from Eq. (11) under a Gaussian error model: with the error variance set to its maximum-likelihood estimate SSE/n, the maximized log-likelihood is \(-\frac{n}{2}\ln (SSE/n)\) plus a constant that cancels in the difference. Because LOOCV is computationally expensive, we ran a single trial of 10-fold cross-validation in our experiments. This introduces some Monte-Carlo variability, because the estimate is not averaged over all possible leave-one-out sets as it would be with LOOCV73, and we observed slight variations in the estimates across runs. To obtain stable and reliable estimates, we recommend that future researchers average over multiple runs, as suggested in73.

$$\begin{aligned} \hat{p} \approx n/2 \ln \left( \frac{S S E_{CV}}{n}\right) -n/2 \ln \left( \frac{S S E_{full}}{n}\right) \end{aligned}$$

(12)

These terms capture goodness of fit by indicating how close the model's predictions are to the observed data. The supplementary files provide additional results on how model complexity changes under varying parameters. Henceforth, the term 'number of parameters' for non-neural models in this study denotes the effective complexity \(\hat{p}\).
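Equation (12) can be computed with scikit-learn's `cross_val_predict`. This is a sketch under assumptions: for multi-output targets such as per-class logits, it sums SSE over all outputs and takes n as the number of rows (nodes), which the text does not specify; the function names are hypothetical. The repeated variant averages over several shuffles, following the recommendation above. For ordinary least squares, the estimate lands near the number of fitted coefficients.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_predict


def effective_complexity(model, X, y, n_splits=10, seed=0):
    """Eq. 12: p_hat = n/2 * ln(SSE_CV / n) - n/2 * ln(SSE_full / n)."""
    y = np.asarray(y, dtype=float)
    n = y.shape[0]  # assumption: observations = rows, SSE summed over all outputs
    full_model = clone(model).fit(X, y)
    sse_full = float(np.sum((y - full_model.predict(X)) ** 2))
    if sse_full == 0.0:
        raise ValueError("SSE_full is zero (model interpolates); Eq. 12 is undefined.")
    cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    y_cv = cross_val_predict(model, X, y, cv=cv)  # out-of-fold predictions
    sse_cv = float(np.sum((y - y_cv) ** 2))
    return n / 2 * np.log(sse_cv / n) - n / 2 * np.log(sse_full / n)


def repeated_effective_complexity(model, X, y, n_repeats=5, n_splits=10):
    """Mean and spread of Eq. 12 over several differently shuffled K-fold runs."""
    estimates = [effective_complexity(model, X, y, n_splits, seed=s) for s in range(n_repeats)]
    return float(np.mean(estimates)), float(np.std(estimates))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 3))
    y = X @ np.array([1.0, 2.0, -1.0]) + rng.normal(size=500)
    # For OLS with an intercept the estimate should be roughly 4 (3 slopes + intercept)
    print(repeated_effective_complexity(LinearRegression(), X, y, n_repeats=3))
```

Because 10-fold training sets are 10% smaller than the full set, the estimate is slightly inflated relative to LOOCV; averaging over repeats reduces variance but not this small bias.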

Based on the approximated complexity, we compute a distillation quality score that measures the effectiveness of the distillation process. Inspired by79, we employ a slightly modified version of their distillation quality metric to evaluate the student models, as shown in Eq. 13. Instead of accuracy, we use the weighted F1-score in the metric when dealing with imbalanced datasets.

$$\begin{aligned} DS = \alpha \cdot \left( \frac{\text {student}_c}{\text {teacher}_c}\right) + (1 - \alpha ) \cdot \max \left( 0, 1 - \frac{\text {student}_{f1}}{\text {teacher}_{f1}}\right) \end{aligned}$$

(13)

Here, \(\textrm{student}_c\) and \(\textrm{teacher}_c\) denote the respective complexities (in terms of parameters), and \(\textrm{student}_{f1}\) and \(\textrm{teacher}_{f1}\) denote the respective weighted F1-scores. The approach for computing the number of parameters of our student models is described in the section "Estimating the complexity of tree-based ensemble models-an approximation and distillation quality score". The authors of79 note that the choice of \(\alpha\) is left to the designer, allowing either model size or accuracy to be prioritized according to the system's requirements; for instance, \(\alpha > 0.5\) would be appropriate if a smaller model size is more critical. In our work, we set \(\alpha = 0.5\) to give equal weight to model size and performance. For balanced datasets, accuracy can be used instead of F1-scores. When the student outperforms the teacher, the ratio of student to teacher performance exceeds one, so the second term uses the max function to keep it non-negative. Lower scores indicate better distillation: the score approaches zero when the student matches or outperforms the teacher while being much smaller than it.
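Equation 13 is a direct computation. In the sketch below the parameter counts and F1-scores in the usage lines are made-up numbers, used only to show the arithmetic: a student with 5,000 effective parameters and weighted F1 of 0.90 against a teacher with 100,000 parameters and F1 of 0.92 scores 0.5 × 0.05 + 0.5 × (1 − 0.90/0.92) ≈ 0.036.

```python
def distillation_score(student_params, teacher_params, student_f1, teacher_f1, alpha=0.5):
    """Eq. 13: weighted size ratio plus non-negative relative F1 drop (lower is better)."""
    size_term = student_params / teacher_params
    # max(0, ...) keeps the term at zero when the student outperforms the teacher
    perf_term = max(0.0, 1.0 - student_f1 / teacher_f1)
    return alpha * size_term + (1.0 - alpha) * perf_term


if __name__ == "__main__":
    # Hypothetical numbers for illustration only
    print(distillation_score(5_000, 100_000, 0.90, 0.92))
    print(distillation_score(5_000, 100_000, 0.95, 0.92))  # student better: size term only
```

When the student beats the teacher, only the size term remains, so the score reduces to α times the complexity ratio.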

Ablation studies

We conducted three ablation studies, focusing primarily on the cell graph datasets. The first study explored training with ensembled logits from the teacher and the best-performing student model. The second analyzed differences in feature importance between models trained on teacher logits and models trained on hard labels. The third compared the effectiveness of distilling teacher knowledge into two types of student models: an artificial neural network (ANN) and non-neural models.

Combining teacher and top student: ensemble model training

The goal of knowledge distillation from multiple teachers is to produce a strong student that inherits most of the ensemble's performance without increasing the computational cost of inference. Producing strong students through distillation first requires building highly predictive teacher ensembles80. A few works focus on ensemble distillation on unlabeled datasets81,82,83. Since our study focuses on labeled data, we specifically consider approaches relevant to labeled datasets, where the crucial problem is how to weight the individual teachers within the ensemble81. The authors of84 proposed an ensemble model that unified three distinct knowledge distillation methods (feature-based, response-based, and relation-based) and evaluated it on the CIFAR-10 and CIFAR-100 benchmarks, distilling a ResNet-110 teacher with 1.7 million parameters into a lightweight ResNet-20 student with 0.27 million parameters. The authors of85 trained an ensemble of Multi-Task Deep Neural Networks (MT-DNNs) as teachers, which outperformed any single model, and then distilled the ensemble's knowledge into a single MT-DNN student through multi-task learning. Wang et al.86 trained one segmentation teacher CNN on synthetic samples with accurately known ground-truth fault labels and one classification teacher CNN on field samples with manually annotated labels; a classification student network was then trained on samples created by aggregating the predictions of both teachers through a voting mechanism. The authors of87 proposed MT-BERT, a multi-teacher knowledge distillation approach for compressing pre-trained language models. They devised a co-finetuning framework that simultaneously fine-tuned multiple teacher models using a unified pooling and prediction module to align their output hidden states, which enhanced the collaborative teaching of the student model.

Chebotar and Waters88 built an effective ensemble of LSTM and CLDNN acoustic models trained with diverse objectives and used a CLDNN as the student. They first identified the optimal fixed weights for merging the teachers' outputs to maximize accuracy and then distilled the knowledge into the student using the soft labels generated by the ensemble. The authors of89 proposed a dynamic weighting approach for each teacher, treating distillation as a multi-objective optimization problem to find a more effective training direction, and demonstrated its effectiveness for both logits-based and feature-based distillation through extensive experiments.

For this ablation study, we treat both CG-JKNN and the highest-performing student model as teacher models to investigate their combined impact on knowledge distillation. We adopt the methodology proposed in88, which identifies optimal fixed weights for merging the outputs of teacher models; here, the weights are chosen to maximize the F1-score on the validation set. We then distill a student model from the ensemble output generated with these optimized weights. Equation (14) shows how the outputs of the teacher GNN and LightGBM models are aggregated, and the detailed approach is given in Algorithm 1.

$$\begin{aligned} L_{\text{ ensemble } }(x)=w_{\text{ gnn } } \cdot L_{\text{ gnn } }(x)+w_{\text{ lightgbm } } \cdot L_{\text{ lightgbm } }(x) \end{aligned}$$

(14)

Here, \(L_{\text{ ensemble } }(x)\) is the ensembled output for a given input x, \(L_{\text{ gnn } }(x)\) is the logit output of the GNN model, \(L_{\text{ lightgbm } }(x)\) is the raw decision score output of the LightGBM model, and \(w_{\text{ gnn } }\) and \(w_{\text{ lightgbm } }\) are the weights applied to the outputs of the GNN and LightGBM models, respectively. Equation (14) can be extended to incorporate the outputs of other high-performing student models.

Algorithm 1 Optimal weight finding for the ensemble of the teacher GNN and the best student model.
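As a rough illustration of the weight search behind Eq. (14), the Python sketch below grid-searches a fixed weight pair on validation outputs and keeps the pair with the highest weighted F1 score. It is a minimal sketch under stated assumptions, not a reproduction of Algorithm 1: the weights are assumed to sum to one, the grid step of 0.1 is arbitrary, and the helper names (weighted_f1, ensemble_scores, find_ensemble_weights) are hypothetical.

```python
import numpy as np


def weighted_f1(y_true, y_pred, num_classes):
    """Support-weighted F1 over classes (classes absent from y_true get zero weight)."""
    f1_sum = 0.0
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom > 0 else 0.0
        f1_sum += f1 * np.sum(y_true == c)  # weight each class F1 by its support
    return f1_sum / len(y_true)


def ensemble_scores(gnn_logits, lgbm_scores, w_gnn, w_lgbm):
    """Eq. (14): weighted sum of GNN logits and LightGBM raw decision scores."""
    return w_gnn * gnn_logits + w_lgbm * lgbm_scores


def find_ensemble_weights(gnn_logits, lgbm_scores, y_val, step=0.1):
    """Grid-search w_gnn in [0, 1] with w_lgbm = 1 - w_gnn (assumed constraint).

    Returns (best_w_gnn, best_weighted_f1); ties keep the first weight found.
    """
    num_classes = gnn_logits.shape[1]
    best_w, best_f1 = None, -1.0
    for w_gnn in np.round(np.arange(0.0, 1.0 + 1e-9, step), 10):
        scores = ensemble_scores(gnn_logits, lgbm_scores, w_gnn, 1.0 - w_gnn)
        f1 = weighted_f1(y_val, scores.argmax(axis=1), num_classes)
        if f1 > best_f1:
            best_w, best_f1 = float(w_gnn), f1
    return best_w, best_f1
```

Because GNN logits and LightGBM raw scores may live on different scales, the selected weights also absorb that scale difference; searching the two weights independently instead of constraining them to sum to one would be an equally plausible variant.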

Feature importance: comparing students trained with and without teacher guidance

We aimed to analyze the differences in feature importance between student models trained on teacher logits and their counterparts trained on hard labels. The literature suggests that students trained on logits are better equipped to mimic the behavior of the teacher model18. This analysis can therefore also serve to explore how far a student model trained on logits may act as a partial proxy for interpreting the teacher's decision-making process. For this experiment, we selected the student model that performed best on the held-out test set. To determine feature importance, we used the model's feature_importances_ attribute. Additionally, to assess how some of the most important features contribute to the predictions for each class, and in which direction, we employed SHapley Additive exPlanations (SHAP) plots90. Our objective was not to compare these techniques but to use SHAP for a deeper understanding of how features influence model predictions. In future work, we plan to incorporate more advanced techniques such as permutation-based methods (e.g., Boruta importance)91 and knockoff approaches92, as these methods provide a more robust and accurate assessment of a feature's predictive ability within a model93.
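One simple way to summarize how the two students' importances differ, assuming the importance vectors have already been extracted (for example from each model's feature_importances_ attribute, which in LightGBM's scikit-learn interface reports split counts by default), is to compare their rankings. The sketch below is hypothetical and not part of the analysis described above: it computes a Spearman rank correlation (without tie correction) and the overlap of the top-k features. The SHAP analysis is not reproduced.

```python
import numpy as np


def rank_vector(values):
    """Rank positions (0 = smallest); ties are broken by order of appearance."""
    order = np.argsort(values, kind="stable")
    ranks = np.empty(len(values), dtype=float)
    ranks[order] = np.arange(len(values))
    return ranks


def compare_importances(imp_hard, imp_logits, feature_names, top_k=5):
    """Compare importances of a hard-label student and a logit-trained student.

    Returns the Spearman rank correlation (Pearson correlation of the ranks)
    and the set of features that appear in both models' top-k lists.
    """
    imp_hard = np.asarray(imp_hard, dtype=float)
    imp_logits = np.asarray(imp_logits, dtype=float)
    spearman = np.corrcoef(rank_vector(imp_hard), rank_vector(imp_logits))[0, 1]
    top_hard = {feature_names[i] for i in np.argsort(-imp_hard)[:top_k]}
    top_logits = {feature_names[i] for i in np.argsort(-imp_logits)[:top_k]}
    return float(spearman), top_hard & top_logits
```

A correlation near 1 with a large top-k overlap would indicate that distillation changed little about which features the student relies on; a low value would point to features whose role shifted under teacher guidance.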

Note that the student model can act as an interpretable approximation of the teacher, reflecting the teacher's emphasis on particular cell graph-level or morphological features. However, it cannot exploit the graph structure and the complex node relationships that the teacher captures through message passing. The student operates solely on feature values and on the logits provided by the teacher, which limits its ability to fully replicate the teacher's reasoning process.

Comparing effectiveness of knowledge distillation into ANN vs. non-neural student models

In this ablation study, we selected an ANN as the neural student to ensure both model types rely solely on the features and implicit relational knowledge provided through the logits of the teacher GNN. This avoids the additional advantage of directly exploiting cell graph structures that a GNN would have and ensures that any observed differences in performance stem directly from the effectiveness of the distillation process.

We designed a shallow network with one hidden layer to keep the student model small; its structure is illustrated in Fig. 3. The hyperparameters, including the hidden dimension, alpha (which balances the two loss terms), and the learning rate, were optimized with Optuna over 50 trials, selecting the configuration that maximized the validation F1 score. We also constrained the hyperparameter search space so that the number of ANN parameters remained comparable to that of the non-neural student models.

Figure 3 Architecture of our shallow ANN student model. The ellipses denote additional neurons in a layer that are not explicitly drawn, for clarity.
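The following NumPy sketch shows the forward pass of a one-hidden-layer student like the one in Fig. 3 and counts its parameters, which is the quantity one would keep comparable to the tree-based students when constraining the search space. The ReLU activation, the initialization, the example hidden width of 32, and the helper names are assumptions for illustration; training with the loss in Eq. (15) is omitted.

```python
import numpy as np


def init_shallow_ann(n_features, n_hidden, n_classes, seed=0):
    """One hidden layer: input -> hidden (assumed ReLU) -> class logits."""
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(0.0, np.sqrt(2.0 / n_features), (n_features, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.normal(0.0, np.sqrt(1.0 / n_hidden), (n_hidden, n_classes)),
        "b2": np.zeros(n_classes),
    }


def forward(params, x):
    """Return raw logits of shape (n_samples, n_classes)."""
    hidden = np.maximum(0.0, x @ params["W1"] + params["b1"])
    return hidden @ params["W2"] + params["b2"]


def count_parameters(params):
    """Total number of trainable weights and biases."""
    return sum(p.size for p in params.values())
```

For example, with 64 input features and 9 classes (the dimensions of the placenta data described earlier) and a hidden width of 32, the network has 64·32 + 32 + 32·9 + 9 = 2,377 parameters.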

Hinton et al.15 found that the student model learns considerably more effectively when it is trained on both the soft targets provided by the teacher model and the ground-truth labels. This approach uses a combined loss function with two components: the standard cross-entropy loss and a knowledge distillation-specific loss term.

The overall knowledge distillation loss is given in Eq. (15).

$$\begin{aligned} L_{KD} = \alpha L_{CE}\left( p_s, y\right) + (1-\alpha ) \tau ^2 KL\left( p_s^\tau , p_t^\tau \right) \end{aligned}$$

(15)

Here, \(L_{CE}\left( p_s, y\right)\) is the cross-entropy loss between the student's predictions and the ground-truth labels y. The second component, \(\tau ^2 KL\left( p_s^\tau , p_t^\tau \right)\), is the knowledge distillation term, where \(p_s^\tau\) and \(p_t^\tau\) denote the softened outputs of the student and teacher models, respectively, after temperature scaling with parameter \(\tau\). KL denotes the Kullback-Leibler divergence, which measures how one probability distribution diverges from a second, reference distribution. \(\alpha\) is a hyperparameter that balances the cross-entropy loss and the knowledge distillation loss. In our work, we observed that the logits before calibration already produced good results; consequently, we set the temperature \(\tau = 1\).

Hinton et al.15 suggested using a weighted average of the distillation loss and the student loss, weighting the distillation term by \(\beta = 1-\alpha\); in one of their experiments, they used \(\alpha = \beta = 0.5\). Other works that use knowledge distillation treat this weight as a tunable parameter94,95,96. In our work, we treat \(\alpha\) as a hyperparameter and additionally report results with a fixed \(\alpha = 0.5\).
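A minimal NumPy sketch of Eq. (15), averaged over a batch, is given below. Eq. (15) writes the divergence as \(KL(p_s^\tau, p_t^\tau)\) without fixing the argument order; the sketch assumes the usual convention in which the teacher distribution is the reference, i.e. it computes \(KL(p_t^\tau \,\|\, p_s^\tau)\). The defaults \(\tau = 1\) and \(\alpha = 0.5\) mirror the settings described above; the function names are hypothetical.

```python
import numpy as np


def log_softmax(z, tau=1.0):
    """Numerically stable log-softmax of temperature-scaled logits."""
    z = np.asarray(z, dtype=float) / tau
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def kd_loss(student_logits, teacher_logits, labels, alpha=0.5, tau=1.0):
    """Eq. (15), averaged over samples:

    alpha * CE(student, hard labels)
    + (1 - alpha) * tau^2 * KL(teacher_tau || student_tau)
    """
    n = len(labels)
    log_ps = log_softmax(student_logits)                # CE uses the unsoftened student output
    ce = -log_ps[np.arange(n), labels].mean()
    log_ps_tau = log_softmax(student_logits, tau)
    log_pt_tau = log_softmax(teacher_logits, tau)
    pt_tau = np.exp(log_pt_tau)
    kl = (pt_tau * (log_pt_tau - log_ps_tau)).sum(axis=1).mean()
    return alpha * ce + (1.0 - alpha) * tau**2 * kl
```

The \(\tau^2\) factor keeps the gradient magnitude of the distillation term roughly independent of the temperature; with \(\tau = 1\), as used here, it has no effect.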

Generalizability of knowledge distillation under various dataset complexities

To investigate whether all models benefit from knowledge distillation and to assess the effectiveness of our approach across dataset complexities, we conducted experiments on multiple datasets (cell graph and non-cell graph). These datasets presented challenges such as distribution shifts and structural complexities in the training and test graphs. Importantly, for the Coauthorship datasets, we did not extract local graph features but instead used the original dataset features. This allowed us to test the efficacy of knowledge distillation in the absence of graph-specific features. The logits obtained from GNNs trained on these coauthor networks can encapsulate rich information because they reflect the relationships between node features (keywords) and the graph structure (the coauthorship network). For instance, if an author is involved in interdisciplinary work, their logits may encode soft probabilities across multiple fields, capturing the uncertainty or overlap between class labels.

Graph complexity

We hypothesize that for knowledge distillation to be effective when the teacher is a GNN learning from the graph, the graph must possess sufficient complexity. In such cases, the logits transferred from the GNN provide valuable information that student models can leverage.

According to the literature, graph complexity measures can be categorized into deterministic and probabilistic methods97. Deterministic approaches include Kolmogorov complexity, substructure counting, and generative models. Probabilistic methods apply entropy functions (such as Shannon's entropy) to probability distributions over graph structures and are further divided into intrinsic and extrinsic subcategories. In our work, we use graph energy, a concept originating in molecular and quantum chemistry, as a metric to evaluate how graph structural complexity affects knowledge transfer from a teacher GNN to student models98,99. It is computed using Eq. (16).

$$\begin{aligned} C=\left( \frac{1}{|A|} \sum _{k=1}^{|A|} b_k\right) \sum \operatorname {SVD}(M) \end{aligned}$$

(16)

Here, \(b_k\) represents the weight of the k-th edge (if the graph is weighted), |A| denotes the number of edges in the graph, and \(\operatorname{SVD}(M)\) is the vector of singular values of the matrix M98, whose entries are summed in Eq. (16).
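The sketch below evaluates Eq. (16) for an undirected graph given as an edge list. Eq. (16) does not state how M is built, so the sketch assumes M is the binary adjacency matrix and that edge weights enter only through the mean-weight factor. For an unweighted graph (all \(b_k = 1\)) the result reduces to the sum of the singular values of the adjacency matrix, which for a symmetric matrix equals the sum of the absolute eigenvalues, i.e. the classical graph energy. The function name is hypothetical.

```python
import numpy as np


def graph_complexity(edges, num_nodes, weights=None):
    """Eq. (16): mean edge weight times the sum of singular values of M.

    Assumption: M is the symmetric binary adjacency matrix of an undirected graph.
    """
    if weights is None:
        weights = np.ones(len(edges))          # unweighted graph: every b_k = 1
    weights = np.asarray(weights, dtype=float)
    adjacency = np.zeros((num_nodes, num_nodes))
    for u, v in edges:
        adjacency[u, v] = adjacency[v, u] = 1.0
    singular_values = np.linalg.svd(adjacency, compute_uv=False)
    return float(weights.mean() * singular_values.sum())
```

For a three-node path graph the adjacency eigenvalues are \(\sqrt{2}, 0, -\sqrt{2}\), so the sketch returns \(2\sqrt{2} \approx 2.83\). Duplicate edges and self-loops are not treated specially.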

Distribution shift in the data

Distribution shift100,101,102 can be broadly categorized into three types: covariate shift, label shift, and concept shift. Under covariate shift, the feature distribution changes while the relationship between features and labels (the conditional label distribution) stays the same. Under label shift, the label distribution changes while the class-conditional feature distribution stays the same. Concept shift, also called concept drift, arises when the actual relationship between the inputs and the labels evolves, reflecting a change in the underlying concept the model is attempting to capture. Covariate shift can be detected in several ways: by comparing summary statistics, by using dissimilarity measures such as the Earth mover's distance, or, for statistical rigor, by applying hypothesis tests such as the Kolmogorov-Smirnov or Chi-squared tests to determine whether distributional differences are significant103.

For this work, we used Kernel Principal Component Analysis (Kernel PCA) for dimensionality reduction, retaining the number of components that captured more than 95% of the dataset's variance. We then ran univariate Kolmogorov-Smirnov (KS) tests on each of the K retained components, applying a Bonferroni correction104 to a significance level of \(\alpha = 0.01\) to control the cumulative Type I error rate across the multiple hypotheses. The mean of all significant KS statistics was computed to summarize the extent of covariate shift across the K dimensions. For the computationally expensive TB and Placenta datasets, we subsampled 20,000 points to keep the analysis feasible while preserving the representativeness of the original data. The mean KS statistic may not fully reflect the degree of shift in the dataset; however, our primary goal was to demonstrate the presence of a shift.
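A NumPy-only sketch of the column-wise test with Bonferroni correction is given below. It assumes the Kernel PCA step has already been applied (for example with scikit-learn's KernelPCA), so each column of the two input arrays holds one retained component for two data subsets. Instead of exact p-values, it compares each KS statistic with the standard large-sample critical value \(c(\alpha')\sqrt{(n+m)/(nm)}\), where \(c(\alpha') = \sqrt{-\ln(\alpha'/2)/2}\) and \(\alpha' = \alpha/K\); scipy.stats.ks_2samp would be the usual route to p-values. The function names are hypothetical.

```python
import numpy as np


def ks_2samp_stat(a, b):
    """Two-sample Kolmogorov-Smirnov statistic D = max_x |F_a(x) - F_b(x)|."""
    a, b = np.sort(a), np.sort(b)
    grid = np.concatenate([a, b])  # ECDFs only jump at sample points
    cdf_a = np.searchsorted(a, grid, side="right") / len(a)
    cdf_b = np.searchsorted(b, grid, side="right") / len(b)
    return float(np.max(np.abs(cdf_a - cdf_b)))


def detect_covariate_shift(X_a, X_b, alpha=0.01):
    """Column-wise KS tests with a Bonferroni-corrected large-sample critical value.

    X_a, X_b: arrays of shape (n, K) and (m, K), e.g. Kernel PCA scores of two subsets.
    Returns the mean of the significant KS statistics and a boolean mask per column.
    """
    n, m = len(X_a), len(X_b)
    k = X_a.shape[1]
    alpha_adj = alpha / k                                # Bonferroni correction
    c_alpha = np.sqrt(-0.5 * np.log(alpha_adj / 2.0))    # c(alpha') for the two-sample test
    d_crit = c_alpha * np.sqrt((n + m) / (n * m))
    stats = np.array([ks_2samp_stat(X_a[:, j], X_b[:, j]) for j in range(k)])
    significant = stats > d_crit
    mean_d = float(stats[significant].mean()) if significant.any() else 0.0
    return mean_d, significant
```

The large-sample critical value is an approximation that becomes accurate as both subsets grow, which is a reasonable regime for subsamples of the size mentioned above.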

To determine covariate shift in the non-cell graph-based datasets, we calculated the percentage of features with covariate shift by performing univariate KS tests directly on the scaled features. We did so because these datasets are high-dimensional: the large number of components required to capture 95% of the variance would have made the Kernel PCA-based approach described above computationally expensive. For label shift detection, we employed the Chi-squared test105 to evaluate the consistency of class distributions across the different data subsets. This involved constructing a contingency table from the frequency counts of each unique class in these subsets. After computing the Chi-squared statistic, we used the p-value to determine whether the observed distributional differences were statistically significant.
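The Chi-squared label-shift check can be sketched in plain Python as below, assuming two label subsets and a 2 × C contingency table of class counts. To stay dependency-free, the survival function uses the closed forms for integer degrees of freedom; in practice scipy.stats.chi2_contingency computes the same test, although by default it applies Yates' continuity correction when there is one degree of freedom, which this sketch does not. The function names are hypothetical.

```python
import math
from collections import Counter


def chi2_sf(x, dof):
    """Survival function P(X > x) of the chi-squared distribution, integer dof >= 1."""
    if x <= 0:
        return 1.0
    if dof % 2 == 0:
        # exp(-x/2) * sum_{i=0}^{dof/2 - 1} (x/2)^i / i!
        term, total = 1.0, 1.0
        for i in range(1, dof // 2):
            term *= (x / 2.0) / i
            total += term
        return math.exp(-x / 2.0) * total
    # Odd dof: erfc(sqrt(x/2)) + sqrt(2x/pi) e^{-x/2} sum_r x^{r-1} / (1*3*...*(2r-1))
    total = math.erfc(math.sqrt(x / 2.0))
    term = math.sqrt(2.0 * x / math.pi) * math.exp(-x / 2.0)
    for r in range(1, (dof + 1) // 2):
        total += term
        term *= x / (2 * r + 1)
    return total


def label_shift_test(labels_a, labels_b):
    """Chi-squared test of homogeneity on the class-count contingency table."""
    counts_a, counts_b = Counter(labels_a), Counter(labels_b)
    classes = sorted(set(counts_a) | set(counts_b))
    table = [[counts_a[c] for c in classes], [counts_b[c] for c in classes]]
    row_totals = [sum(row) for row in table]
    col_totals = [sum(col) for col in zip(*table)]
    n = sum(row_totals)
    stat = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            expected = row_totals[i] * col_totals[j] / n
            stat += (observed - expected) ** 2 / expected
    dof = (len(table) - 1) * (len(classes) - 1)
    return stat, dof, chi2_sf(stat, dof)
```

A small p-value (for example below 0.01) would indicate that the class proportions differ between the two subsets, i.e. evidence of label shift.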

Can logit calibration enhance student guidance?

Neural networks often produce poorly calibrated predictions that can be either overconfident or underconfident, and GNNs can be miscalibrated too106. Calibration aims primarily to make predicted probabilities more reliable. In our study, we were particularly interested in whether logit calibration could enhance the guidance provided to our student models. Note that logit calibration does not affect the performance of the teacher model itself. Previous studies107,108 have shown how calibration can affect models' accuracy and other performance metrics. In addition, the authors in109 introduced the idea of addressing mis-instruction through logit calibration; they showed that enhancing the target logits while preserving the relative proportions among the non-target logits can significantly improve the utility of logits for knowledge distillation. These works primarily dealt with neural models as students. Wang et al.110 observed that GNNs tend to be underconfident, in contrast to most multi-class classifiers, which are generally overconfident; this motivates the use of calibration techniques for GNN logits. Guo et al.111 proposed temperature scaling to address the miscalibration of modern neural networks. Kuleshov et al.112 introduced a straightforward calibration method based on isotonic regression. Another approach is ensemble-based temperature scaling113. Methods such as temperature scaling preserve accuracy because they leave the ranking of each node's logits unchanged114.

To achieve calibration, we employed isotonic regression and temperature scaling as post-hoc calibration methods. Isotonic regression is traditionally applied to binary classification; to extend it to the multiclass setting, we adopted a one-vs-all strategy115,116. We measured the (stratified) Brier score and the negative log-likelihood before and after calibration, because both are proper scoring rules and provide a truthful measure of the accuracy of probabilistic predictions117. To learn the temperature T, it is considered best practice to use a validation set or to perform cross-validation. We used 5-fold cross-validation (2 folds if the dataset was highly imbalanced) by splitting the training logits into training and validation folds, and we learned two temperatures on the validation folds: one optimizing the Brier score and one optimizing the log loss. In this paper, the probabilities obtained after calibration using Eq. (17) are referred to as calibrated probabilities (calibrated probs). The overall score reported in the paper is the mean of the scores calculated individually for each class.

$$\begin{aligned} \hat{p}_i = \frac{\exp \left( \frac{z_i}{T}\right) }{\sum _{j=1}^C \exp \left( \frac{z_j}{T}\right) } \end{aligned}$$

(17)

where \(\hat{p}_i\) represents the calibrated probability for class \(i\), \(z_i\) is the logit for class \(i\) (pre-softmax output of the model), \(T > 0\) is the temperature parameter learned using a validation set or cross-validation, and \(C\) is the total number of classes.
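A minimal NumPy sketch of fitting T in Eq. (17) on held-out logits follows. It assumes a simple grid search over T (the optimizer is not specified in the text), uses the ordinary multiclass Brier score rather than the stratified variant, and omits both the cross-validation loop and isotonic regression; the helper names are hypothetical. Passing score=brier instead of the default negative log-likelihood gives the Brier-optimized temperature.

```python
import numpy as np


def softmax_t(logits, T):
    """Eq. (17): temperature-scaled softmax."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def nll(probs, labels):
    """Mean negative log-likelihood of the true classes."""
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))


def brier(probs, labels):
    """Multiclass Brier score (not the stratified variant)."""
    onehot = np.eye(probs.shape[1])[labels]
    return np.mean(np.sum((probs - onehot) ** 2, axis=1))


def fit_temperature(logits, labels, score=nll, grid=None):
    """Return the T on the grid that minimises `score` on held-out logits."""
    if grid is None:
        grid = np.linspace(0.05, 10.0, 200)
    scores = [score(softmax_t(logits, T), labels) for T in grid]
    return float(grid[int(np.argmin(scores))])
```

Because dividing all logits of a node by the same positive T does not change their order, the predicted class stays the same, which is why temperature scaling preserves accuracy.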

Experimental setup and hyperparameters

We implemented the models using the PyTorch framework118 and ran them on one NVIDIA A100 GPU. The hyperparameters of the teacher model were selected with Optuna119, a Python library for hyperparameter optimization. We ran 50 trials, aiming to maximize the weighted F1 score on the validation set for the imbalanced datasets. We used the cross-entropy loss during training when the class imbalance was mild or moderate, and a weighted cross-entropy loss when the class imbalance was extreme. The teacher model was trained for 80 epochs with the Adam optimizer. The teacher hyperparameters for each dataset are listed in Table 6. The features were scaled using a standard scaler. As performance metrics, we evaluated accuracy and the weighted F1 score. Table 6 also lists the temperatures used to calibrate the logits: the first temperature minimizes the stratified Brier score, and the second minimizes the log loss.

To keep the student models small, we set the number of estimators to 6, varied the maximum depth between 8 and 16 (e.g., 8, 12, or 16), and fixed the number of leaf nodes at 50, except for our complex TB dataset, for which we allowed 300 leaf nodes. The learning rate of the boosters was set to 0.3, and all other parameters were kept at their default values. The specific depths of the student models are detailed in the results section for each dataset. Note that the reported student model performances are specific to the chosen hyperparameter configurations, and the results could vary with a more extensive hyperparameter search.
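For concreteness, the student settings above could be expressed as keyword arguments for LightGBM's scikit-learn interface, as in the sketch below. It assumes the LightGBM student (other boosters name these parameters differently), and the helper student_params and the dataset key "TB" are hypothetical; all unlisted parameters stay at their defaults.

```python
def student_params(dataset: str, max_depth: int) -> dict:
    """Hypothetical helper: keyword arguments for lightgbm.LGBMClassifier."""
    if not 8 <= max_depth <= 16:
        raise ValueError("max_depth was varied between 8 and 16")
    return {
        "n_estimators": 6,                               # boosting rounds
        "max_depth": max_depth,                          # e.g. 8, 12 or 16, chosen per dataset
        "num_leaves": 300 if dataset == "TB" else 50,    # larger trees for the complex TB data
        "learning_rate": 0.3,
    }


# Usage (requires lightgbm):
# from lightgbm import LGBMClassifier
# student = LGBMClassifier(**student_params("TB", 12))
```

For multiclass problems LightGBM grows one tree per class in each boosting round, so the total number of trees is 6 times the number of classes.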

Table 6 Teacher model hyperparameters and temperature values for datasets.

The edge homophily ratios of the graphs used are shown in Table 7. Edge homophily is the fraction of edges in a graph that connect nodes with the same class label, and it is computed using Eq. (18).

$$\begin{aligned} h=\frac{\left| \left\{ (u, v):(u, v) \in \mathscr {E} \wedge y_u=y_v\right\} \right| }{|\mathscr {E}|} \end{aligned}$$

(18)

where h denotes the edge homophily ratio, \(|\mathscr {E}|\) is the total number of edges in the graph, (u, v) represents an edge between nodes u and v, and \(y_u\) and \(y_v\) are the labels of nodes u and v, respectively.

As stated in120, a high edge homophily ratio (\(h \rightarrow 1\)) indicates strong homophily, while a low ratio (\(h \rightarrow 0\)) indicates strong heterophily.
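Eq. (18) amounts to a single pass over the edge list; a plain-Python sketch follows (the function name is hypothetical). If an undirected graph stores each edge in both directions, both copies have the same endpoint labels, so the ratio is unchanged.

```python
def edge_homophily(edges, labels):
    """Eq. (18): fraction of edges (u, v) whose endpoints share a class label."""
    if not edges:
        raise ValueError("graph has no edges")
    same = sum(1 for u, v in edges if labels[u] == labels[v])
    return same / len(edges)
```

For example, on the path 0-1-2-3 with labels [0, 0, 1, 1], two of the three edges join same-label nodes, giving h = 2/3.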

Table 7 Edge homophily ratios.


