
Masked Language Modeling foundation model for clinical data

29/1/2026
14 min

I. Introduction

At Scienta Lab, we leverage AI foundation models to push the boundaries of immune-mediated inflammatory disease (IMID) research. One of the main features of these diseases is that they aren’t as clearly defined as one might believe. Rather, they usually form a continuum of diseases with strong symptom overlaps [1, 2]. By applying AI to real-world data, our goal is to uncover new associations between subgroups of IMID patients and find new ways to stratify these diseases (e.g. by discovering new biomarkers).

Having these objectives in mind, it is little wonder that we’re interested in using AI to represent patients via high-dimensional embeddings. Indeed, representation models have proved effective at capturing intrinsic relationships and latent structures within textual data [3, 4, 5]. More specifically, such representations are quite convenient when one wants to analyze a patient’s evolution over time, perform prediction (e.g. drug response prediction), or cluster a group of patients.

In this blogpost, we’ll walk you through the main steps we carried out to build a small in-house medical encoder model.


II. Data

II.A. Electronic Health Record (EHR) dataset

We’re working with a UK EHR dataset called OPCRD [6]. This dataset contains the medical follow-up of 2.7M IMID patients. More precisely, it is a large .txt table with more than 3 billion entries, each of them representing one specific medical concept that was transcribed by a general practitioner, for a given patient, on a given date. Concepts range from purely administrative ones (e.g. Referral letter) to very specific symptoms (e.g. Abscess of axilla) or disease diagnoses (e.g. Pemphigus paraneoplastica).

Each medical concept is represented by a unique ID. In our case, we’re working with the Unified Medical Language System (UMLS) ontology. For this specific ontology, codes look like this: C5401210 (Referral letter), C0263115 (Abscess of axilla).

Because the OPCRD dataset uses a less convenient ontology called SNOMED, we took the time to map all SNOMED codes to UMLS codes, taking care to lose as little information as possible.

II.B. Structuring the data

We transformed the input table into a patient-level dataset in which each entry contains all the visits of a given patient, each visit being represented by a list of medical concepts. Then, every time we train a model, the dataloader transforms each patient’s full follow-up into a single sequence of tokens (note that working with UMLS codes makes tokenization easy: one code equals one token), which can be fed to a transformer encoder model. The figure below sums up (in a simplified way) the data transformation steps.
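To make the flattening step concrete, here is a minimal Python sketch. The record structure (a list of visits, each a list of UMLS codes) and the special tokens are illustrative, not our exact in-house format:

```python
def build_vocab(concept_codes):
    """Map each UMLS code to a unique token id (one code = one token)."""
    specials = ["[PAD]", "[CLS]", "[SEP]", "[MASK]"]
    vocab = {tok: i for i, tok in enumerate(specials)}
    for code in sorted(concept_codes):
        vocab.setdefault(code, len(vocab))
    return vocab

def patient_to_tokens(visits, vocab):
    """Flatten a patient's visits into one token-id sequence."""
    ids = [vocab["[CLS]"]]
    for visit in visits:
        ids.extend(vocab[c] for c in visit)
        ids.append(vocab["[SEP]"])  # mark the visit boundary
    return ids

# Two toy visits: a referral letter, then an abscess plus another referral.
visits = [["C5401210"], ["C0263115", "C5401210"]]
vocab = build_vocab({"C5401210", "C0263115"})
tokens = patient_to_tokens(visits, vocab)
```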

II.C. Curating the data

As you probably all know: garbage in, garbage out. Well, guess what… This has never been so true.

Vocabulary-level filtering

In the ideal world we’d love to live in, we could have obtained nice results using this freshly structured dataset. The truth is we had to cut the vocabulary down by a large margin in order to get rid of outlier concepts (examples are given in the paragraph below) and make the model less memory-intensive. (NB: we noticed a significant difference in memory usage between a tokenizer with ~25,000 tokens and one with ~60,000 tokens. Indeed, for a model with a hidden size of 256, reducing the vocabulary from 60,000 to 25,000 shrinks the embedding matrix from 15.3M to 6.4M parameters, hence the memory savings.)

To this end, we implemented several filtering steps to make our dataset as clean as possible. First, we filtered out administrative concepts and other concepts irrelevant to our use case (e.g. “Letter sent to consultant”; “Redundancy payment”). We think this step was of paramount importance because overrepresented, uninformative concepts made the other, interesting ones nearly invisible to the model. Then, we required a concept to appear at least 100 times to be kept; rarer concepts were filtered out (e.g. “Tenacious character”, “Wool Alcohol”, “Thoracic Neoplasms”, “Contusion of chin”).
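As an illustration, the frequency-based filtering step can be sketched as follows (the nested patient structure and the `blocklist` of administrative codes are placeholders, not our actual pipeline):

```python
from collections import Counter

def filter_vocabulary(patients, min_count=100, blocklist=frozenset()):
    """Keep concepts appearing at least `min_count` times, excluding
    administrative codes listed in `blocklist`.

    `patients` is a list of patient records, each a list of visits,
    each visit a list of UMLS codes (illustrative structure).
    """
    counts = Counter(c for visits in patients for visit in visits for c in visit)
    return {c for c, n in counts.items() if n >= min_count and c not in blocklist}
```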

As a result, we managed to cut the vocabulary size from 270k SNOMED concepts (in the original table) down to 36,058 UMLS concepts.

Sample-level filtering

Also, to prevent the model from focusing on outlier patients, we removed visits with more than 512 concepts (this removed only 0.00005% of visits; the median and 90th-percentile number of concepts per visit are 2 and 21), as well as patients with fewer than 3 visits (~6,000 patients (=0.2%) were discarded).
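A minimal sketch of this sample-level filtering, again assuming the illustrative nested-list patient structure:

```python
def filter_patients(patients, max_visit_len=512, min_visits=3):
    """Drop outlier visits (more than `max_visit_len` concepts), then drop
    patients left with fewer than `min_visits` visits."""
    kept = []
    for visits in patients:
        trimmed = [v for v in visits if len(v) <= max_visit_len]
        if len(trimmed) >= min_visits:
            kept.append(trimmed)
    return kept
```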

As a result, we went from the original .txt OPCRD table of 2,691,123 patients to an Arrow dataset of 2,685,031 patients (40 GB), which we split into train and validation sets using a 98%-2% ratio.

We found a validation set of ~50,000 patients was more than enough to correctly evaluate our model while training it.

III. Training and predicting

III.A. Training the model

We trained a RobertaPreLayerNorm [7] model of 12.6M parameters from scratch using a classic Masked Language Modeling (MLM) objective (with a masking probability of 15%). We used a hidden size of 256, as we found that a hidden size of 512 made training longer and more resource-intensive without proving beneficial in terms of embedded information.
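For readers unfamiliar with MLM, here is a simplified, dependency-free sketch of BERT-style masking. In practice this is handled by the training framework’s data collator; the 80/10/10 replacement split is the standard recipe, and the function below is purely illustrative:

```python
import random

def mask_tokens(ids, mask_id, vocab_size, mlm_prob=0.15, seed=0):
    """BERT-style MLM masking: each token is selected with prob `mlm_prob`;
    of the selected positions, 80% become [MASK], 10% a random token,
    10% stay unchanged. Labels are -100 (ignored by the loss) everywhere
    except at selected positions, where they hold the original token."""
    rng = random.Random(seed)
    inputs, labels = list(ids), []
    for i, tok in enumerate(ids):
        if rng.random() < mlm_prob:
            labels.append(tok)  # the model must predict this token
            r = rng.random()
            if r < 0.8:
                inputs[i] = mask_id
            elif r < 0.9:
                inputs[i] = rng.randrange(vocab_size)
            # else: keep the original token
        else:
            labels.append(-100)
    return inputs, labels
```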

We used 8 V100 GPUs to perform training over 18 epochs. The training lasted 33 hours. We stopped it when both the training and validation losses seemed to have plateaued for a long time (see curves below).

About the context-window:

We decided to work with a context-window of 256 tokens. Obviously, our dataset contains patients with follow-ups longer than 256 concepts (see distribution below).

Length of patients’ follow-ups in our Arrow dataset.

To overcome this limitation, we divide the tokenized input sequence into chunks (see figure below for a visual explanation). This means each patient is represented by N chunks of size 256 (in the figure, C1 contains the first 256 tokens). First, each chunk is passed through the model, and average pooling is performed to retrieve one embedding per chunk (a1 is the resulting embedding of chunk C1). Then, to preserve an approximate self-attention between concepts belonging to different chunks, we perform cross-attention between the N chunk embeddings, as if we had a sequence of N tokens, each token representing a previously forwarded chunk (a1 becomes b1). Finally, we average-pool the resulting embeddings of these N chunks to create the patient embedding (see figure below).
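The chunk → pool → cross-attend → pool flow can be sketched as follows. Note that this is a heavily simplified stand-in: in the real pipeline each chunk goes through the trained encoder and the cross-chunk attention is a learned layer, whereas below we use parameter-free dot-product attention purely to show the data flow:

```python
import math

def average_pool(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    n, d = len(vectors), len(vectors[0])
    return [sum(v[j] for v in vectors) / n for j in range(d)]

def self_attention(vectors):
    """Single-head scaled dot-product attention with no learned projections
    (an illustrative stand-in for the model's cross-chunk attention)."""
    d = len(vectors[0])
    out = []
    for q in vectors:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in vectors]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        w = [x / z for x in w]
        out.append([sum(wi * v[j] for wi, v in zip(w, vectors)) for j in range(d)])
    return out

def embed_patient(token_embs, chunk_size=256):
    """Chunk the sequence, pool each chunk (a_i), let chunk embeddings
    attend to each other (b_i), then pool into one patient embedding."""
    chunks = [token_embs[i:i + chunk_size] for i in range(0, len(token_embs), chunk_size)]
    a = [average_pool(c) for c in chunks]  # one embedding per chunk
    b = self_attention(a)                  # cross-chunk attention
    return average_pool(b)
```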

Now you may ask: why not use a bigger context size? Well, first, RobertaPreLayerNorm has a context-size limit of 512 tokens, so we had to work with chunks anyway. Second, we noticed that working with smaller chunks made training significantly faster, which is explained by the $O(n^2)$ complexity of self-attention: performing self-attention on a sequence of 512 tokens takes nearly twice as long as performing self-attention on two chunks of 256 and then cross-attention between those 2 chunks.
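This back-of-the-envelope comparison can be checked with a quick pair count (an approximation that counts only the quadratic attention interactions and ignores everything else):

```python
def attention_pair_count(seq_len, chunk_size):
    """Token-pair interactions: full self-attention over the whole sequence
    vs. per-chunk self-attention plus cross-attention between chunk summaries."""
    full = seq_len ** 2
    n_chunks = seq_len // chunk_size
    chunked = n_chunks * chunk_size ** 2 + n_chunks ** 2
    return full, chunked

# 512 tokens at once vs. two chunks of 256:
# full = 262,144 pairs, chunked = 131,076 pairs, i.e. roughly half.
```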

Given that we wanted to have as much “pure” self-attention as possible, we decided not to decrease the context size too much and opted for 256.

About positional encoding


Regarding the positional encoding, we chose to work with the patient’s age, as is done in [8]. The main advantage is that it conveys to the model both the ordering of visits and the time elapsed between them. Moreover, we do not need an exact ordering of concepts within a visit because (i) we’re developing an encoder-only model and are not interested in generating ordered sequences, and (ii) we’re working with already encoded information, for which order is less relevant than for natural language.
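A simplified version of such age-based position ids is sketched below (the exact granularity, years here, and the input structure are our assumptions, not necessarily those of the actual implementation):

```python
from datetime import date

def age_position_ids(visit_dates, concepts_per_visit, birth_date):
    """One position id per token: the patient's age in whole years at the
    visit the token belongs to (a BEHRT-style age encoding, simplified)."""
    ids = []
    for d, n in zip(visit_dates, concepts_per_visit):
        # Subtract one if the birthday hasn't occurred yet that year.
        age = d.year - birth_date.year - ((d.month, d.day) < (birth_date.month, birth_date.day))
        ids.extend([age] * n)  # every token in the visit shares the same age
    return ids
```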

Assessing embeddings’ relevance

Even though the training curves looked good, we wanted to make sure the learnt embeddings were relevant enough for our use case.

To this end, we fetched the embeddings of a set of concepts of interest (~150 concepts representing diseases). We then projected these embeddings into a 3D space and used Tensorboard’s built-in selection tool, which, when applied to a point, highlights the top-N closest points in the original vector space. We used cosine similarity as the distance metric and N=10 to run a sanity check on the diseases of interest.
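The neighbor lookup behind this sanity check boils down to a cosine-similarity top-N search, sketched here (the dictionary mapping UMLS codes to vectors is an illustrative structure):

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def top_n_neighbors(query_code, embeddings, n=10):
    """Return the n concept codes closest to `query_code` by cosine similarity.
    `embeddings` maps concept code -> embedding vector."""
    q = embeddings[query_code]
    others = [(c, cosine(q, v)) for c, v in embeddings.items() if c != query_code]
    return [c for c, _ in sorted(others, key=lambda t: t[1], reverse=True)[:n]]
```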

The checks made us confident that the model had learnt relevant embeddings. The figures below show examples for two concepts of interest. The right side displays the 10 closest concepts (see the red-gradient list; the higher in the list, the closer to the studied concept). The left side shows where these 10 closest points are located in the 3D UMAP projection of the embeddings.


III.B. Exploring patients’ embeddings

Now that we have a trained model, our goal is to compare the patient embeddings it outputs and see how they can help us stratify the studied population.

Skin-IMIDs dataset

For the purpose of this blog post, we’ll use a dermatology-specific subpopulation (extracted from the 2.6M-patient Arrow dataset). Namely, we’re interested in patients suffering from skin-IMIDs like atopic dermatitis (AD), psoriasis (Pso), hidradenitis suppurativa (HS), prurigo nodularis (PN), and bullous pemphigoid (BP). We extracted this population of patients and ended up with a skin-IMIDs dataset of 53,228 patients.

Considering specific CUI types only

We used the trained model to obtain patients’ embeddings, then clustered the patients based on these embeddings and tried to interpret the obtained clusters. We quickly noticed that our embeddings did not reflect symptoms and skin-disease diagnoses enough, as we did not end up with clear disease clusters. Instead, biological concepts (e.g. lab tests) accounted for the main part of the embeddings. This makes sense since these were the most frequent concepts, given that our data is collected from general practitioner appointments.

To overcome this issue, we leveraged one of the UMLS ontology’s features: concept types. Indeed, each concept belongs to a category of concepts (e.g. “Biological Function”, “Body Location or Region”, “Laboratory or Test Result”, “Disease or Syndrome”), which is very convenient when one wants to consider specific concepts only.

To this end, we slightly adapted the forward pass. First, all tokens in the sequence attend to each other (as usual). However, when averaging the token embeddings, we only consider medical concepts belonging to symptom and diagnosis categories (in practice, we set the embeddings of irrelevant tokens to the null vector and divide only by the number of non-null vectors, so that these concepts do not contribute to the average pooling). See figure below for a visual explanation.
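This masked average pooling can be sketched like so (`keep_mask` flags the symptom/diagnosis tokens; everything else is zeroed out of the average):

```python
def masked_average_pool(token_embs, keep_mask):
    """Average only over tokens flagged in `keep_mask` (e.g. symptom and
    diagnosis concepts). The other tokens still attended during the forward
    pass; they are simply excluded from the pooled patient embedding."""
    d = len(token_embs[0])
    kept = [e for e, k in zip(token_embs, keep_mask) if k]
    if not kept:
        return [0.0] * d  # degenerate case: no relevant concept in sequence
    return [sum(e[j] for e in kept) / len(kept) for j in range(d)]
```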

We end up with patient embeddings that are “diagnoses and symptoms”-oriented.

Considering interesting concepts neighborhood

The method previously presented made embeddings way more disease-oriented. Yet, we wanted to go further in this direction.

We decided to manually identify ~250 concepts of high interest based on what we’d like the embeddings to focus on (symptoms and diagnoses of skin-IMIDs in our case). To extend this list, we then fetched the 10 closest concepts for each of them (using cosine similarity) and ended up with a list of 1,294 interesting concepts.

We performed the same adapted forward pass (described above), but this time considering only concepts belonging to this list of interesting concepts.

We observed that embeddings were even more disease-oriented with this method, and decided to stick with it.

Application: skin-IMIDs population clustering

We predicted embeddings for the 53,228 skin-IMIDs patients. Then, we used a spectral clustering algorithm to cluster them. We chose this algorithm mainly because it can handle non-convex clusters, which, given that we work in a relatively high dimension (256), are more than likely to appear.
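With scikit-learn, the clustering step might look like the following sketch. The `nearest_neighbors` affinity and its `n_neighbors` value are assumptions on our part, not necessarily the exact settings used:

```python
import numpy as np
from sklearn.cluster import SpectralClustering

def cluster_patients(embeddings, n_clusters=20, seed=0):
    """Spectral clustering over patient embeddings. Spectral clustering can
    handle non-convex clusters; n_clusters=20 matches the deliberately large
    number of clusters mentioned in the text. The kNN affinity graph below
    is an assumed configuration, shown for illustration."""
    model = SpectralClustering(
        n_clusters=n_clusters,
        affinity="nearest_neighbors",  # assumed: kNN graph affinity
        n_neighbors=10,
        assign_labels="kmeans",
        random_state=seed,
    )
    return model.fit_predict(np.asarray(embeddings))
```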

As a matter of fact, we also tried agglomerative clustering and obtained less meaningful clusters than with spectral clustering.

We deliberately chose a large number of clusters (n=20) to see what clusters we would obtain.

Cluster interpretation

To further explain the obtained clusters, we decided to work with word clouds. For each cluster, we computed a concept count and kept the 50 most important concepts. We defined a concept’s “importance” as the resulting p-value of the chi-square test of independence [10] between the observed frequency of a concept (in the cluster of interest) and the expected frequency (defined as the frequency of appearance if the concept were uniformly distributed across clusters). The lower the p-value, the more important the concept is for the associated cluster. We display concepts with a font size proportional to their importance. Please note that we only display concepts with a p-value lower than 0.1, hence some word clouds contain fewer than 50 concepts.
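Since the relevant contingency table is 2x2 (this concept vs. all other concepts, inside vs. outside the cluster), the test has one degree of freedom, and the p-value can even be computed without SciPy. A minimal sketch of the importance score (the 2x2 framing is our reading of the test described above):

```python
import math

def concept_pvalue(count_in_cluster, cluster_total, count_overall, grand_total):
    """Chi-square test of independence on a 2x2 table (1 degree of freedom)
    comparing a concept's observed frequency in one cluster against its
    overall frequency. A lower p-value means a more cluster-specific concept."""
    table = [
        [count_in_cluster, cluster_total - count_in_cluster],
        [count_overall - count_in_cluster,
         (grand_total - cluster_total) - (count_overall - count_in_cluster)],
    ]
    row = [sum(r) for r in table]
    col = [sum(c) for c in zip(*table)]
    chi2 = sum(
        (table[i][j] - row[i] * col[j] / grand_total) ** 2
        / (row[i] * col[j] / grand_total)
        for i in range(2) for j in range(2)
    )
    # Survival function of chi-square with 1 dof: P(X > x) = erfc(sqrt(x/2)).
    return math.erfc(math.sqrt(chi2 / 2))
```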

For the sake of confidentiality of this project, we won’t display here detailed results of the interpretation.

To better grasp the topology of the clustering in the 2D projection space, we drew the figure below by analyzing closely the word clouds of each cluster.

Is our clustering biased?

The objective of this section is to explain the validation we implemented to detect biases in our clustering.

To do so, we found it quite convenient to color the exact same scatter plot with colors based on statistics like age, region, sex, number of concepts, treatment taken, and ethnicity.

Two statistics caught our attention: age and number of concepts (see the colored plots below).

Obviously, these two statistics are highly correlated. But what is also striking is the higher prevalence of high values (dark orange) in the center (and, similarly, of low values (dark blue) at the extremes).

We call this phenomenon “collapsing”: the older the patient (and the more concepts in their follow-up), the more “averaged” their embedding will be (as it is roughly defined as an average of the embeddings of all their concepts).

Hence, older patients tend to be located in the center, along with patients having only “general” concepts in their follow-up.

One solution to this problem would be to implement weighted average pooling, giving each concept a weight inversely proportional to its frequency, so as to temper the influence of very common concepts and highlight the presence of rarer ones.
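Such inverse-frequency weighted pooling could be sketched as below. The 1/count weighting is one simple choice among several possible schemes (e.g. TF-IDF-style weights), not an implemented method:

```python
def weighted_average_pool(token_embs, token_counts):
    """Weight each token embedding inversely to its corpus frequency, so that
    very common 'general' concepts pull the patient embedding toward the
    center less, and rarer, more specific concepts stand out more."""
    weights = [1.0 / c for c in token_counts]
    z = sum(weights)
    d = len(token_embs[0])
    return [sum(w * e[j] for w, e in zip(weights, token_embs)) / z for j in range(d)]
```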

Other solutions could involve studying a patient’s evolution over time rather than their embedding at a fixed time point.

IV. Conclusion

The achieved results show that transformer-based language models are suitable to grasp medical information from raw clinical visits and embed it relevantly in a high-dimensional vector space. The clustering algorithm applied to the patients’ embeddings highlighted both very distinctive clusters of patients and well-known overlaps, confirming the model’s ability to compare IMID patients.

Data access and rights

This study is based wholly on data from the Optimum Patient Care Research Database (www.opcrd.co.uk) obtained under a limited licence from Optimum Patient Care Limited and its execution is approved by recognized experts affiliated to the Respiratory Effectiveness Group. However, the interpretation and conclusions contained in this report are those of the authors alone.

References

  1. McGonagle D, McDermott MF. A proposed classification of the immunological diseases. PLoS Med. 2006 Aug;3(8):e297. doi: 10.1371/journal.pmed.0030297. PMID: 16942393; PMCID: PMC1564298.
  2. Moutsopoulos HM. Autoimmune rheumatic diseases: One or many diseases? J Transl Autoimmun. 2021 Oct 28;4:100129. doi: 10.1016/j.jtauto.2021.100129. PMID: 35005593; PMCID: PMC8716565.
  3. Devlin J, Chang M-W, Lee K, Toutanova K. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://arxiv.org/abs/1810.04805
  4. Reimers N, Gurevych I. Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks. https://aclanthology.org/D19-1410/
  5. Gao T, Yao X, Chen D. SimCSE: Simple Contrastive Learning of Sentence Embeddings. https://arxiv.org/abs/2104.08821
  6. Optimum Patient Care Research Database (OPCRD). https://www.opcrd.optimumpatientcare.org/
  7. RoBERTa-PreLayerNorm model documentation, Hugging Face Transformers. https://huggingface.co/docs/transformers/en/model_doc/roberta-prelayernorm
  8. Li, Y., Rao, S., Solares, J.R.A. et al. BEHRT: Transformer for Electronic Health Records. Sci Rep 10, 7155 (2020). https://doi.org/10.1038/s41598-020-62922-y
  9. Robinson CA, Love LW, Saleh HM, et al. Nummular Dermatitis. [Updated 2024 Mar 1]. In: StatPearls [Internet]. Treasure Island (FL): StatPearls Publishing; 2024 Jan-. Available from: https://www.ncbi.nlm.nih.gov/books/NBK565878/
  10. scipy.stats.chi2_contingency, SciPy documentation. https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chi2_contingency.html