Recruiting Digital Patients: My Journey into AI in Healthcare (Step 4)
From exploring AI in recommender systems to taking the plunge into healthcare, my journey into AI has been both exhilarating and eye-opening. Step 4 of my journey brought me into uncharted territory: using AI for real-world medical applications.
This was more than just a technical milestone; it was my first chance to see AI interact with real-world medical challenges.
Guided by an amazing mentor at SRI, I discovered layers of AI that I hadn’t encountered before. This experience taught me that complexity doesn’t always mean improvement, that metrics can sometimes deceive, and that tools like Grad-CAM are essential for understanding what our models see.
This project opened the door to my first major research study, titled ‘Recruiting Digital Patients: The Effectiveness of Training Models on AI-Generated Cardiac MRI Data,’ with the abstract published in The Young Researcher. Here’s a look into how I approached this project, the challenges I faced, and the insights I gained with my mentor’s guidance along the way.
The Student Research Institute (SRI) by the Harvard Undergraduate OpenBio Laboratory is a virtual summer program to increase access to research for high school students in the natural and physical sciences.
The Healthcare AI Challenge
Imagine a world where AI has fully transformed medicine — a world where diagnoses are enhanced by predictive models and treatment plans are customized for every individual. We’re on the brink of this future, yet a crucial obstacle stands in the way: access to the vast, high-quality patient data needed to fuel these advancements.
Here lies the dilemma: while patient data is essential for developing life-saving AI applications, strict privacy laws and ethical responsibilities make it nearly impossible to share real patient information.
Our project set out to tackle a central question: can we create “digital patients” to train AI models, pushing medical AI forward while protecting real patient privacy? With my mentor’s guidance, I took on this ambitious challenge.
Enter synthetic data. Using Generative Adversarial Networks (GANs), we can generate highly realistic artificial medical images that closely resemble real data without directly exposing any real patient’s records. Synthetic data offers a promising solution to this data-access challenge, and my project focused on evaluating its potential. Specifically, I investigated whether machine learning models trained exclusively on AI-generated cardiac MRIs could accurately predict critical clinical metrics, like the Left Ventricular Sphericity Index (SI), when tested on real-world data.
Building the Machine Learning Pipeline
Our pipeline begins with data preparation, involving two distinct datasets:
Real-World Validation Data
For testing, we used real MRI images from the UK Biobank, each labeled by two physicians. This dual labeling provided a crucial benchmark — allowing us to compare our model’s predictions against the natural variation between human experts. The SI labels were provided to us by the authors of this paper on cardiac sphericity as an early marker of cardiomyopathy.
The Left Ventricular Sphericity Index (SI) is a crucial metric in cardiology used to assess the shape of the heart’s left ventricle; deviations in this shape can signal early stages of heart conditions like cardiomyopathy, making SI a key predictor in cardiac health assessment.
Synthetic Training Data
For training, we worked with AI-generated cardiac MRI images created by the authors of the GANcMRI model. The critical step in our work was calculating the Sphericity Index (SI) for these synthetic images. We developed a custom application to label these synthetic images by marking the short and long axes of the left ventricle. While labor-intensive, this precise labeling was essential for our training data’s quality.
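Because the labeling app recorded the endpoints of the left ventricle’s long and short axes, SI can be computed as a ratio of two lengths. Below is a minimal sketch, assuming SI is defined as short-axis length divided by long-axis length (one common 2D definition; other formulations exist) and that each axis is stored as a pair of (x, y) pixel coordinates. The function names and point format are hypothetical, not the app’s actual code.

```python
import math


def axis_length(p1, p2):
    """Euclidean distance between two (x, y) landmark points, in pixels."""
    return math.dist(p1, p2)


def sphericity_index(long_axis, short_axis):
    """Hypothetical SI: short-axis length / long-axis length.

    Each axis is a pair of (x, y) endpoints clicked in the labeling tool.
    Values closer to 1.0 indicate a rounder (more spherical) ventricle.
    Assuming square pixels, pixel spacing cancels out because SI is a ratio.
    """
    long_len = axis_length(*long_axis)
    short_len = axis_length(*short_axis)
    if long_len == 0:
        raise ValueError("long axis has zero length")
    return short_len / long_len


# Example: an 80 px long axis and a 40 px short axis
si = sphericity_index(long_axis=((50, 10), (50, 90)),
                      short_axis=((30, 50), (70, 50)))
print(round(si, 3))  # 0.5
```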
💡 Pro Tip: When working with medical imaging data, start by understanding the clinical metrics you’re trying to predict. I spent the first week just learning about the Sphericity Index and its importance in cardiac assessment — this knowledge proved invaluable when designing our labeling application.
In testing, we used both synthetic and real MRI data to assess how well the model trained on synthetic data could generalize to real-world cases.
Figure 1 illustrates our complete pipeline: starting with synthetic data generation and labeling, progressing through data preparation and augmentation, model training with ResNet50, and finally validation on real cardiac MRI data.
The Power of Data Augmentation: After preparing our datasets, the next crucial step was enhancing our training data through augmentation. When working with medical imaging, particularly with limited datasets, data augmentation becomes essential for model robustness. Our augmentation pipeline focused on realistic variations that could occur in cardiac MRI acquisition, including slight rotations, contrast adjustments, and controlled blur effects. Here’s how we implemented these transformations:
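from torchvision import transforms
import random
from torchvision.transforms import functional as TF

def get_train_transforms(mean, std):
    # Expects float image tensors of shape (C, H, W); Normalize does not accept PIL images.
    return transforms.Compose([
        # Small random rotations (up to ±15°) mimic slight differences in slice orientation
        transforms.RandomRotation(degrees=15),
        # With probability 0.5, scale contrast by a random factor in [0.8, 1.2]
        transforms.RandomApply([lambda img: TF.adjust_contrast(img, contrast_factor=random.uniform(0.8, 1.2))], p=0.5),
        # With probability 0.5, scale brightness by a random factor in [0.8, 1.2]
        transforms.RandomApply([lambda img: TF.adjust_brightness(img, brightness_factor=random.uniform(0.8, 1.2))], p=0.5),
        # With probability 0.3, blur with a sigma sampled uniformly from [0.1, 5.0].
        # transforms.GaussianBlur samples sigma from this range; TF.gaussian_blur would
        # instead treat (0.1, 5) as fixed x/y sigmas (a strong, always-vertical blur).
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0))], p=0.3),
        # Standardize with dataset statistics (mean/std are tensors or arrays)
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])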
💡 Pro Tip: Only augment training data to help your model learn to generalize. Keep validation and test data unaltered so they represent the real data distribution.
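One way to enforce this is to give each data split its own transform and let the dataset apply whatever it is handed. Here is a minimal, framework-free sketch: `SIDataset`, `toy_augment`, and `toy_normalize` are hypothetical stand-ins (in the real pipeline the training transform would be something like `get_train_transforms`, and the evaluation transform would be normalization only).

```python
import random


class SIDataset:
    """Hypothetical dataset of (image, SI label) pairs with a per-split transform."""

    def __init__(self, samples, transform=None):
        self.samples = samples        # list of (image, si) pairs
        self.transform = transform    # None means images are returned unchanged

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image, si = self.samples[idx]
        if self.transform is not None:
            image = self.transform(image)
        # Labels are never transformed: rotation, contrast, brightness, and blur
        # leave the ventricle's axis lengths (and therefore SI) unchanged
        return image, si


def toy_normalize(image, mean=0.5, std=0.25):
    """Stand-in for Normalize: applied to every split."""
    return [(p - mean) / std for p in image]


def toy_augment(image):
    """Stand-in for random contrast: scale intensities by a factor in [0.8, 1.2]."""
    factor = random.uniform(0.8, 1.2)
    return [p * factor for p in image]


def train_transform(image):
    return toy_normalize(toy_augment(image))


# Toy "images" (flattened pixel lists) with made-up SI labels
samples = [([0.2, 0.5, 0.8], 0.52), ([0.1, 0.4, 0.9], 0.61)]
train_ds = SIDataset(samples, transform=train_transform)  # random: augmented, then normalized
val_ds = SIDataset(samples, transform=toy_normalize)      # deterministic: normalized only
```

Reading `val_ds[0]` twice always returns the same image, while `train_ds[0]` changes from call to call, which is exactly the split between training and evaluation behavior the tip describes.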
Model Architecture and Training
This phase was a journey of discovery, showing me firsthand that more complex models aren’t always better. Starting with ConvNeXt, I was optimistic because of its depth and sophistication. But I quickly ran into overfitting — mean squared error (MSE) looked impressive, but the actual predictions were off.
Switching to ResNet18 simplified the architecture and reduced overfitting, delivering better consistency. Finally, ResNet50 struck the right balance, giving me the precision I needed without excessive complexity.
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. https://doi.org/10.1109/CVPR.2016.90
While ResNet50 is traditionally known for classification tasks, we discovered its depth and skip connections made it remarkably adaptable for our regression task. By modifying the final layer, we transformed this powerful architecture into a precise tool for predicting the Sphericity Index.
Metrics Matter (and Can Mislead): Here’s another eye-opener: one metric can sometimes appear impressive without reflecting real-world performance. Initially, ConvNeXt’s MSE appeared promising, yet the predictions didn’t align. This was my first experience learning to look beyond metrics and consider the model’s actual output.
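A toy illustration of how a low MSE can hide a useless model: when the target spans a narrow range, as SI values do, a model that ignores its input and always predicts the average already scores a small MSE. Comparing against that baseline, for instance with R², exposes the problem. This is one way a metric can mislead, not a claim about what happened with ConvNeXt, and the numbers below are made up for illustration.

```python
def mse(y_true, y_pred):
    """Mean squared error between two equal-length sequences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r_squared(y_true, y_pred):
    """1 - MSE(model) / MSE(predict-the-mean); 0 means no better than the mean."""
    mean = sum(y_true) / len(y_true)
    baseline = mse(y_true, [mean] * len(y_true))
    if baseline == 0:
        raise ValueError("targets have zero variance")
    return 1 - mse(y_true, y_pred) / baseline


# Made-up SI targets clustered in a narrow band
y_true = [0.48, 0.52, 0.50, 0.55, 0.45, 0.50]
constant = [0.50] * len(y_true)                  # ignores the image entirely
tracking = [0.49, 0.51, 0.50, 0.54, 0.46, 0.50]  # follows each case closely

print(mse(y_true, constant), r_squared(y_true, constant))
print(mse(y_true, tracking), r_squared(y_true, tracking))
```

The constant predictor’s MSE is below 0.001, which sounds excellent, yet its R² is essentially 0; the tracking predictor’s R² is above 0.9.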
Here’s an overview of the core model architecture:
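import torch
import torch.nn as nn
import torch.nn.functional as F  # used by training_step
import pytorch_lightning as L
import torchvision

class CNN(L.LightningModule):
    def __init__(self, model_architecture, loss_function, learning_rate, weight_decay):
        super().__init__()
        self.model = self.initialize_model(model_architecture)
        self.loss_function = loss_function  # 'mse', 'l1', or 'huber'
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, x):
        # Returns predictions of shape (batch, 1)
        return self.model(x)

    def initialize_model(self, model_architecture):
        # Load an ImageNet-pretrained backbone and replace its classification head with a
        # single-output linear layer so the network regresses one value (the SI).
        # Note: torchvision >= 0.13 deprecates pretrained=True in favor of weights="DEFAULT".
        if model_architecture == 'convnext_small':
            model = torchvision.models.convnext_small(pretrained=True)
            model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 1)
        elif model_architecture == 'resnet18':
            model = torchvision.models.resnet18(pretrained=True)
            model.fc = nn.Linear(model.fc.in_features, 1)
        elif model_architecture == 'resnet50':
            model = torchvision.models.resnet50(pretrained=True)
            # Dropout before the regression head for extra regularization
            model.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(model.fc.in_features, 1))
        else:
            raise ValueError(f"Unsupported model architecture: {model_architecture}")
        return model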
💡 Deep Dive: Transfer Learning Strategy
The ResNet50 choice reflects a classic transfer learning approach. By using pretrained=True, we leveraged ImageNet pre-training, so the model began with a rich set of learned image features. Only the final fully connected layer was replaced for our regression task.
Dropout Integration
We placed a dropout layer (p=0.5) before the final linear layer. This is a simple but effective regularization technique, particularly important when fine-tuning a large pretrained model on a smaller medical dataset, where overfitting is a real risk.
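To see what that dropout layer does, here is a plain-Python sketch of inverted dropout, the scheme PyTorch’s `nn.Dropout` uses: during training each feature is zeroed with probability p and the survivors are scaled by 1/(1-p) so the expected activation is unchanged; in evaluation mode the layer passes inputs through untouched. The function name and inputs are illustrative only.

```python
import random


def inverted_dropout(features, p=0.5, training=True, rng=random):
    """Zero each feature with probability p during training; scale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(features)            # eval mode: identity
    if p >= 1.0:
        return [0.0] * len(features)     # everything dropped
    keep_scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else x * keep_scale for x in features]


rng = random.Random(0)
features = [0.2, 1.5, -0.7, 3.0]  # e.g. a slice of ResNet50's 2048-dim pooled features
print(inverted_dropout(features, p=0.5, rng=rng))          # some zeros, survivors doubled
print(inverted_dropout(features, p=0.5, training=False))   # unchanged
```

This is also why calling `model.eval()` before validation matters: otherwise predictions on real MRIs would still be randomly perturbed.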
Training the model required careful selection of architecture, loss function, optimizer, and hyperparameters. I experimented with different configurations, comparing them over multiple iterations until the setup was both accurate and stable.
The final setup used ResNet50, Huber loss, the AdamW optimizer with a learning rate of 0.001, and a batch size of 16. We used PyTorch Lightning to manage training, with a learning rate scheduler that monitored validation loss. The model converged in 80 epochs.
The Loss Function Journey
The choice of loss function proved highly dataset-specific — a crucial lesson that would influence my future projects. In our case, after experimenting with various options, I found:
- MSE (Mean Squared Error) was too sensitive to outliers in our SI measurements
- MAE (Mean Absolute Error) proved overly stable, missing subtle variations in the data
- Huber Loss struck the ideal balance, handling both small variations and occasional outliers in our SI labels
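To compare how the three losses treat a single residual (prediction minus label), here is a plain-Python sketch of their per-sample formulas; PyTorch’s `F.mse_loss`, `F.l1_loss`, and `F.huber_loss` average these over the batch. One detail worth checking: Huber only switches from quadratic to linear once |residual| exceeds `delta`, so with `delta=1.0` it behaves like half the squared error for every residual below 1.0. With SI residuals on the order of 0.1 or less (an MSE of 0.0089 corresponds to a root-mean-square error of about 0.094), a smaller delta may be needed for the outlier-robust behavior to kick in. The delta values below are illustrative.

```python
def squared_error(r):
    return r ** 2


def absolute_error(r):
    return abs(r)


def huber(r, delta=1.0):
    """Quadratic for |r| <= delta, linear beyond (same formula as PyTorch's huber_loss)."""
    a = abs(r)
    if a <= delta:
        return 0.5 * r ** 2
    return delta * (a - 0.5 * delta)


for r in (0.05, 0.5, 2.0):
    print(f"r={r}: MSE={squared_error(r):.4f}  MAE={absolute_error(r):.4f}  "
          f"Huber(d=1.0)={huber(r):.4f}  Huber(d=0.05)={huber(r, 0.05):.4f}")
```

For r=0.5, Huber with delta=1.0 is still in its quadratic region (0.125, exactly half the squared error), while delta=0.05 already treats it linearly.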
💡 Pro Tip: The best loss function depends on your dataset, so compare a few candidates on held-out validation data before committing to one.
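# Method of the CNN LightningModule above (uses F = torch.nn.functional)
def training_step(self, batch, batch_idx):
    x, y = batch
    # squeeze(1) turns (B, 1) into (B,); a bare squeeze() would also drop the
    # batch dimension when B == 1 and break the shape match with y
    y_hat = self.model(x).squeeze(1)
    y = y.float()  # match the model's float32 output
    if self.loss_function == 'mse':
        loss = F.mse_loss(y_hat, y)
    elif self.loss_function == 'l1':
        loss = F.l1_loss(y_hat, y)
    elif self.loss_function == 'huber':
        # Quadratic for |residual| < delta, linear beyond
        loss = F.huber_loss(y_hat, y, delta=1.0)
    else:
        raise ValueError(f"Unsupported loss function: {self.loss_function}")
    self.log("train_loss", loss, on_epoch=True, on_step=False)
    return loss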
Learning Rate Insights
The optimization strategy revealed interesting patterns in model learning. Using PyTorch’s ReduceLROnPlateau scheduler, we monitored validation loss to adaptively adjust the learning rate:
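# Method of the CNN LightningModule above (requires `import torch`)
def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(),
                                  lr=self.learning_rate,
                                  weight_decay=self.weight_decay)
    # Multiply the learning rate by 0.1 once validation loss has failed to
    # improve for more than 20 consecutive epochs
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=20,
        verbose=True
    )
    return {
        'optimizer': optimizer,
        'lr_scheduler': {
            'scheduler': scheduler,
            # 'val_loss' must be logged (e.g. in validation_step) for this to work
            'monitor': 'val_loss',
        }
    }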
Our initial learning rate of 0.001 with AdamW provided stable training, while the scheduler’s patience of 20 epochs emerged from careful observation. Through learning-curve analysis, I found that meaningful improvements often occurred between epochs 15 and 18. Setting patience to 20 struck the right balance: it gave the model enough time to find better solutions before the scheduler cut the learning rate, while still lowering it once improvements genuinely plateaued.
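To make the patience logic concrete, here is a simplified re-implementation of ReduceLROnPlateau’s decision rule in `mode='min'`. It keeps PyTorch’s default relative threshold of 1e-4 but ignores `cooldown`, `min_lr`, and `eps`, so treat it as a sketch for building intuition rather than a replacement for the real scheduler. The validation-loss curve is hypothetical.

```python
def simulate_plateau(val_losses, lr=1e-3, factor=0.1, patience=20, threshold=1e-4):
    """Return the learning rate in effect after each epoch's scheduler step."""
    best = float("inf")
    num_bad_epochs = 0
    history = []
    for loss in val_losses:
        # 'rel' threshold mode: the loss must beat the best so far by a relative margin
        if loss < best * (1 - threshold):
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        # Reduce only after MORE than `patience` consecutive non-improving epochs
        if num_bad_epochs > patience:
            lr *= factor
            num_bad_epochs = 0
        history.append(lr)
    return history


# Hypothetical curve: improves for 10 epochs, then stalls for 30
losses = [1.0 - 0.05 * i for i in range(10)] + [0.55] * 30
lrs = simulate_plateau(losses)
first_drop = next(i for i, v in enumerate(lrs) if v < 1e-3)
print(first_drop)
```

With patience=20, the first reduction happens at epoch index 30, the 21st consecutive epoch without improvement, so a model that typically needs 15 to 18 quiet epochs before improving again is not cut off prematurely.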
Understanding What Our Model Sees
I had always heard about model interpretability but hadn’t truly appreciated its value until this project. Grad-CAM, a visualization technique, allowed me to see which parts of the MRI images the model focused on.
Selvaraju, R. R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., & Batra, D. (2017). Grad-CAM: Visual explanations from deep networks via gradient-based localization. https://doi.org/10.1109/ICCV.2017.74
Figure 2 shows the Grad-CAM visualization from the fourth-to-last convolutional layer of our ResNet50 model. By systematically examining Grad-CAM outputs across intermediate layers, I traced the network’s attention as it progressively converged on the left ventricle. Seeing the final heatmap localize precisely over the left ventricle was a huge aha moment.
Beyond the strong accuracy metrics, we now had visual confirmation that our model was learning to identify clinically relevant cardiac features. This alignment between the model’s focus and clinical relevance not only validated the model’s learning process but also enhanced trustworthiness in its predictions — a critical factor for AI in healthcare, where clinicians need to understand and rely on the model’s decision-making to inform patient care.
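At its core, Grad-CAM weights each feature map of the chosen layer by the average gradient of the output with respect to that map, sums the weighted maps, and keeps only the positive part. For a regression model like ours, the “output” is the predicted SI rather than a class score. The sketch below implements just that arithmetic in plain Python on tiny hand-made activations and gradients; in practice both would be captured with forward and backward hooks on a ResNet50 layer, and the map would be upsampled to the MRI resolution.

```python
def grad_cam(activations, gradients):
    """Grad-CAM heatmap for one layer.

    activations, gradients: lists of K feature maps, each an H x W list of lists,
    where gradients[k][i][j] = d(output) / d(activations[k][i][j]).
    Returns an H x W heatmap scaled to [0, 1].
    """
    h, w = len(activations[0]), len(activations[0][0])
    # alpha_k: global-average-pool the gradients of each feature map
    alphas = [sum(sum(row) for row in g) / (h * w) for g in gradients]
    # Weighted sum of feature maps, followed by ReLU
    cam = [[max(0.0, sum(a * A[i][j] for a, A in zip(alphas, activations)))
            for j in range(w)] for i in range(h)]
    # Normalize for display (guard against an all-zero map)
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


# Two 2x2 feature maps: map 0 fires top-left, map 1 fires bottom-right
acts = [[[1.0, 0.0], [0.0, 0.0]],
        [[0.0, 0.0], [0.0, 1.0]]]
# The output increases with map 0 and decreases with map 1
grads = [[[0.5, 0.5], [0.5, 0.5]],
         [[-0.5, -0.5], [-0.5, -0.5]]]
heatmap = grad_cam(acts, grads)
print(heatmap)
```

Only the top-left cell lights up: the feature map whose activation pushes the prediction up gets highlighted, while the one with a negative gradient is suppressed by the ReLU.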
Evaluating Performance and Final Insights
After training, I evaluated the model on both synthetic and real-world datasets. The results were highly encouraging. When tested on synthetic data, the model achieved an impressive mean squared error (MSE) of just 0.0015. This high accuracy validated the model’s ability to learn from AI-generated cardiac MRIs effectively.
More importantly, the model’s performance remained strong when applied to real-world MRI data, achieving an MSE of 0.0089. To put this in perspective, this accuracy is comparable to the natural variability observed between two independent physicians’ measurements, underscoring the clinical potential of this AI model. These results provide a glimpse into the future possibilities of using synthetic data to train reliable and accurate medical AI models.
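The comparison with physician variability can be made concrete: compute the MSE between the two physicians’ SI measurements as an inter-observer baseline, then the MSE between the model and each physician (or their average). Taking the square root also helps interpretation: an MSE of 0.0089 corresponds to a root-mean-square error of about 0.094 SI units, and 0.0015 to about 0.039. The per-case values below are invented placeholders, not study data, and using the physicians’ average as a consensus reference is an assumption.

```python
import math


def mse(a, b):
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def compare_to_readers(model, reader1, reader2):
    """Inter-reader MSE plus the model's MSE against each reader and their mean."""
    consensus = [(r1 + r2) / 2 for r1, r2 in zip(reader1, reader2)]
    return {
        "reader1_vs_reader2": mse(reader1, reader2),
        "model_vs_reader1": mse(model, reader1),
        "model_vs_reader2": mse(model, reader2),
        "model_vs_consensus": mse(model, consensus),
    }


# Placeholder SI values for four cases (not study data)
reader1 = [0.50, 0.62, 0.45, 0.58]
reader2 = [0.54, 0.60, 0.49, 0.55]
model = [0.52, 0.63, 0.46, 0.57]

for name, value in compare_to_readers(model, reader1, reader2).items():
    print(f"{name}: MSE={value:.5f}  RMSE={math.sqrt(value):.4f}")

# Converting the reported MSEs to RMSE
print(round(math.sqrt(0.0089), 3), round(math.sqrt(0.0015), 3))
```

A model whose error against each reader is of the same order as the readers’ disagreement with each other is performing within the range of human variability.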
Lessons Learned
- Complexity is not always the answer: Finding the simplest model that performs well can reduce overfitting and improve interpretability.
- Evaluate metrics holistically: Single metrics can be misleading; look at the model’s predictions in real-world contexts. For instance, although ConvNeXt’s MSE looked promising, its predictions didn’t align well with the reference SI values, highlighting the importance of looking beyond numbers.
- Tune with intent: Hyperparameter tuning can turn a good model into a great one by balancing accuracy and stability.
Acknowledgement and Looking Ahead
This project wouldn’t have been possible without my mentor, who did more than just guide me — he listened, gauged my capabilities and challenged me! His support helped me navigate tough decisions and pushed me to grow every step of the way. As I look ahead to exploring Alzheimer’s research, I’m excited to bring these lessons forward, grateful for everything this experience has taught me.
This article is part of an ongoing series documenting my journey into AI and Machine Learning as a high school student. Follow along to see my progress and learn with me!