title: Multi-Domain Balanced Sampling Improves Out-of-Distribution Generalization of Chest X-ray Pathology Prediction Models
authors: Tetteh, Enoch; Viviano, Joseph; Bengio, Yoshua; Krueger, David; Cohen, Joseph Paul
date: 2021-12-27

Learning models that generalize under different distribution shifts in medical imaging has been a long-standing research challenge. There have been several proposals for efficient and robust visual representation learning among vision researchers, especially in the sensitive and critical biomedical domain. In this paper, we propose a simple balanced batch sampling technique for out-of-distribution generalization of chest X-ray pathology prediction models. We observe that balanced sampling across the multiple training datasets improves performance over baseline models trained without balancing.

Pathology detection and classification in medical imaging using deep learning [4, 5] continues to be an open research challenge. Although the field of computer vision has progressed significantly owing to advances in model architectures, optimization techniques, and data augmentation, learning relevant feature representations from visual data remains a challenge [2, 9]. Medical imaging with deep learning suffers from the same issues, since it relies mostly on the same deep learning approaches developed for general-purpose datasets such as ImageNet [13]. The issue becomes more challenging when test data is subject to distribution shifts [12]. When confronted with chest X-rays from different datasets, a straightforward approach, followed by [2], is to merge these datasets and form mini-batches by sampling uniformly from the merged dataset.
The question is: can we find a more robust and efficient way of learning representations from medical images by accounting for which dataset each example came from? In this work, we examine how balanced batch sampling from each training dataset can improve a model's generalization to out-of-distribution (OoD) chest X-ray datasets. We compare this to [2] and to a baseline model trained without balanced batching. Our work shows that, by training vision models on chest X-rays using balanced mini-batches, we may achieve performance gains during inference on out-of-distribution chest X-ray datasets.

We aim to classify four chest X-ray pathologies, namely Cardiomegaly, Consolidation, Edema, and Effusion. We chose these pathologies strictly because all the datasets in this experiment include labels for all of them, and not for any other reason. We compare our approach to two baselines that follow state-of-the-art medical image classification approaches; we describe the baselines in Section 2.1. The training, validation, and test datasets all come from different distributions, which ensures that the validation and test datasets are out-of-distribution with respect to the training data. For example, when training on the ChestX-ray8 and CheXpert datasets, we validate on the MIMIC-CXR dataset and test on PadChest.

Datasets. We consider four publicly available chest X-ray image datasets for this experiment: ChestX-ray8 [14], CheXpert [6], MIMIC-CXR-JPG [7], and PadChest [1]. Table 1 shows the details of our datasets. We use a subset of 27,520 images from each dataset for training, 11,008 for validation, and 13,760 for inference. Since we use two datasets as the training set, we end up with 55,040 training images. The training data is sampled sequentially from the original dataset, whereas the entire validation and test datasets are drawn from the data loaders as batches.
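The leave-a-dataset-out rotation described above (two datasets for training, one held out for validation, one for testing) can be enumerated with a short sketch. This is our own illustration, not code from the paper; the fixed val/test assignment order is an assumption chosen so that the example configuration above (train on ChestX-ray8 + CheXpert, validate on MIMIC-CXR-JPG, test on PadChest) appears among the six settings reported in the paper.

```python
import itertools

# The four datasets used in the paper.
DATASETS = ["ChestX-ray8", "CheXpert", "MIMIC-CXR-JPG", "PadChest"]

def leave_a_dataset_out_splits(datasets):
    """Enumerate (train_pair, val, test) configurations: each pair of
    datasets trains one model; the two held-out datasets serve as
    validation and test sets (assignment order is a convention)."""
    splits = []
    for train_pair in itertools.combinations(datasets, 2):
        val, test = [d for d in datasets if d not in train_pair]
        splits.append((train_pair, val, test))
    return splits

splits = leave_a_dataset_out_splits(DATASETS)
# Six configurations in total, matching the six experimental settings.
```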
We use the TorchXRayVision [2, 3] library, which specializes in handling chest X-ray images, to load the raw data. To increase the diversity of the training dataset and improve the robustness of the trained model [15], we perform basic data augmentations.

We perform leave-a-dataset-out cross-validation, training a DenseNet-121 architecture from the torchvision library. The model is fine-tuned [10, 16] from ImageNet weights, without any modification to its architecture except that the number of input channels is changed to 1 (for gray-scale images). We use the Adam [8] optimizer with a fixed learning rate of 1e-03, a weight decay of 1e-05, and amsgrad set to true. We run all experiments for 200 epochs with a batch size of 64, using binary cross-entropy with logits as our loss function. We perform early stopping [11] and use the best validation model state for inference. We run the experiments with three different seeds and report the average. Unless otherwise stated, all results are on the test datasets.

We compare with two different baselines, one taken from previous work and one which we train, following [2], a state-of-the-art method, on our specific datasets:

Baseline XRV is a DenseNet-121 model from the TorchXRayVision library [3] that is trained on eight chest X-ray datasets (including all the datasets we use in our experiments) and tasked to classify 18 chest X-ray pathologies. We only perform inference with this model, for comparison with our models.

Random Batch Sampling follows previous studies of chest X-ray pathology classification across different data distributions, which merge multiple datasets into a larger one for training. We merge two datasets for training, and the remaining two are used for validation and inference respectively.

Balanced Batch Sampling We create two training environments, one for each of our training datasets.
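The balanced batching idea can be sketched in a framework-agnostic way: one shuffled loader per training environment, with each training step drawing a mini-batch from every environment and summing the per-environment losses. This is a minimal illustration, not the paper's code; the helper names (`make_loader`, `balanced_training_step`) and the toy loss function are our own, and whether the batch size of 64 is per environment or total is not specified in the paper.

```python
import random

def make_loader(dataset, batch_size, rng):
    """Cycle through a dataset forever in shuffled mini-batches."""
    while True:
        order = list(range(len(dataset)))
        rng.shuffle(order)
        for start in range(0, len(order) - batch_size + 1, batch_size):
            yield [dataset[i] for i in order[start:start + batch_size]]

def balanced_training_step(loaders, loss_fn, model):
    """One balanced-batch step: draw one mini-batch from every training
    environment, compute each environment's loss, and return the sum
    (this summed loss is what gets back-propagated)."""
    total_loss = 0.0
    for loader in loaders:
        batch = next(loader)
        total_loss += loss_fn(model, batch)
    return total_loss

# Toy demo: two environments, four examples drawn from each per step.
rng = random.Random(0)
env_a = [("env_a", i) for i in range(100)]
env_b = [("env_b", i) for i in range(100)]
loaders = [make_loader(env_a, 4, rng), make_loader(env_b, 4, rng)]
loss = balanced_training_step(loaders, lambda model, batch: float(len(batch)), None)
# loss is 8.0 here: every step sees exactly 4 examples from each environment.
```

The contrast with random batch sampling is that a batch drawn uniformly from the merged dataset gives no guarantee on the per-environment composition, while this step is stratified by construction.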
At each training step, we sample data from each of the environments, compute the individual losses, and back-propagate using the sum of the losses from the environments.

In this section, we present and discuss the findings of our work. Table 2 shows the results of our experiments. Our results suggest that out-of-distribution generalization performance may be improved by using a balanced batching technique, i.e., sampling data from each environment/dataset equally and computing the sum of the losses. From Table 2, we observe performance gains from the balanced batch sampling model over the random batch sampling model: balanced batching outperforms the random sampling approach in all six experimental settings in this work.

Randomly sampling mini-batches from a merged dataset may introduce data bias, because a sampled batch may come from only one of the available distributions, or may contain fewer samples from some distributions than from others. This may result in a model biased towards a certain distribution. Balanced batch sampling, on the other hand, uses a stratified sampling approach, which ensures the algorithm sees data from each distribution at every iteration during training, so the model is less (or not at all) biased towards any one distribution. The training datasets themselves were balanced for both the balanced and random batch sampling models, so it appears that this overall balancing is not as impactful as the balancing of the mini-batches passed to the algorithm. The challenge of sample imbalance within each task is handled by computing a weighted loss. Also, although the Baseline XRV model is trained on much larger data that overlaps the test data, the average AUC of our model trained with balanced batching is almost as good as that of the XRV model.

This research uses only previously public data, so there are no privacy concerns. We do not foresee any negative societal impact as a result of the research described in this work.
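The weighted loss mentioned above for per-task sample imbalance is not spelled out in the paper; one common scheme weights the positive term of the binary cross-entropy per task (akin to the `pos_weight` argument of PyTorch's `BCEWithLogitsLoss`). The sketch below is a pure-Python illustration of that scheme under this assumption, not the paper's exact implementation.

```python
import math

def weighted_bce_with_logits(logits, targets, pos_weights):
    """Mean binary cross-entropy with logits, with a per-task weight on
    the positive-class term to counter label imbalance.

    `logits` and `targets` are lists of per-example lists (one entry per
    task); `pos_weights` holds one positive-class weight per task, e.g.
    the ratio of negative to positive labels for that pathology.
    """
    total, count = 0.0, 0
    for row_x, row_y in zip(logits, targets):
        for x, y, w in zip(row_x, row_y, pos_weights):
            p = 1.0 / (1.0 + math.exp(-x))  # sigmoid of the logit
            total += -(w * y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
            count += 1
    return total / count

# A zero logit on a positive label gives loss log(2); doubling that
# task's positive weight doubles the penalty.
```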
References
[1] PadChest: A large chest x-ray image dataset with multi-label annotated reports.
[2] On the limits of cross-domain generalization in automated X-ray prediction.
[3] TorchXRayVision: A library of chest X-ray datasets and models.
[4] Deep learning.
[5] Deep learning.
[6] CheXpert: A large chest radiograph dataset with uncertainty labels and expert comparison.
[7] MIMIC-CXR: A large publicly available database of labeled chest radiographs.
[8] Adam: A method for stochastic optimization. CoRR, abs/1412.
[9] A unifying view on dataset shift in classification.
[10] A survey on transfer learning.
[11] Early stopping - but when?
[12] Dataset shift in machine learning.
[13] ImageNet large scale visual recognition challenge.
[14] ChestX-ray8: Hospital-scale chest x-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases.
[15] Understanding data augmentation for classification: When to warp?
[16] How transferable are features in deep neural networks? ArXiv.

This research is based on work partially supported by the CIFAR AI and COVID-19 Catalyst Grants. We would like to thank Mila (Quebec AI Institute) for providing computational resources and support that contributed to these research results.