title: Deep Learning-Based Unlearning of Dataset Bias for MRI Harmonisation and Confound Removal
authors: Dinsdale, Nicola K.; Jenkinson, Mark; Namburete, Ana I. L.
date: 2020-10-11 journal: bioRxiv DOI: 10.1101/2020.10.09.332973

Increasingly large MRI neuroimaging datasets are becoming available, including many highly multi-site, multi-scanner datasets. Combining the data from the different scanners is vital for increased statistical power; however, this leads to an increase in variance due to non-biological factors such as differences in acquisition protocols and hardware, which can mask signals of interest. We propose a deep learning based training scheme, inspired by domain adaptation techniques, which uses an iterative update approach to create scanner-invariant features while simultaneously maintaining performance on the main task of interest, thus reducing the influence of scanner on network predictions. We demonstrate the framework for regression, classification and segmentation tasks with two different network architectures. We show that not only can the framework harmonise many-site datasets but that it can also adapt to many data scenarios, including biased datasets and limited training labels. Finally, we show that the framework can be extended to remove other known confounds in addition to scanner. The overall framework is therefore flexible and should be applicable to a wide range of neuroimaging studies.

Combining data across scanners and sites leads to an undesirable increase in non-biological variance, even when attempts have been made to harmonise the acquisition protocols. Following the framework introduced by Ben-David et al. [4], domain adaptation approaches aim to learn a feature representation that is invariant to the domain while minimising the loss on the main task; in the domain-adversarial approach this is achieved with a gradient reversal layer placed between the feature extractor and domain classifier.

Figure 1: General network architecture. The network is formed of three sections: the feature extractor with parameters Θ_repr, the label predictor with parameters Θ_p, and the domain predictor with parameters Θ_d. X_p represents the input data used to train the main task with labels y_p, and X_d represents the input data used in the steps that unlearn scanner, with labels d.

Three loss functions are used in the training of the network. The first loss function is for the primary task and is conditioned on each domain such that the training is not driven by the largest dataset:

L_p(X_p, y_p; \Theta_{repr}, \Theta_p) = \sum_{n=1}^{N} \frac{1}{S_n} \sum_{j=1}^{S_n} L_n(y_{j,n}, \hat{y}_{j,n})    (1)

where N is the number of domains and S_n is the number of subjects from domain n, such that y_{j,n} is the true task label for the j-th subject from the n-th domain. X_p represents the subset of the total dataset X used for the main task, which in all cases will be the full set of X for which we have main task labels y_p available.

The domain information is then unlearned using two loss functions in combination. The first is the domain loss, which assesses how much domain information remains in Θ_repr and simply takes the form of the categorical cross-entropy:

L_d(X_u, d_u; \Theta_{repr}, \Theta_d) = -\sum_{n=1}^{N} \mathbb{1}[d_u = n] \log(p_n)    (2)

where X_u are the input data used for unlearning with domain labels d_u, and p_n are the softmax outputs of the domain classifier.
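For concreteness, the following is a minimal PyTorch sketch of the three sections shown in Fig. 1. The layer choices, feature sizes and the regression head are illustrative assumptions rather than the exact backbones used in the experiments.

```python
# Minimal sketch of the three sections in Fig. 1; layer choices and sizes are
# illustrative assumptions, not the backbones used in the paper.
import torch
import torch.nn as nn

class FeatureExtractor(nn.Module):          # parameters Theta_repr
    def __init__(self, feat_dim=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool3d(2),
            nn.Conv3d(8, 16, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool3d(1))
        self.fc = nn.Linear(16, feat_dim)

    def forward(self, x):                   # x: (batch, 1, D, H, W)
        return self.fc(self.conv(x).flatten(1))

class LabelPredictor(nn.Module):            # parameters Theta_p: main-task head (age regression here)
    def __init__(self, feat_dim=64):
        super().__init__()
        self.fc = nn.Linear(feat_dim, 1)

    def forward(self, z):
        return self.fc(z)

class DomainPredictor(nn.Module):           # parameters Theta_d: scanner classifier
    def __init__(self, feat_dim=64, n_domains=3):
        super().__init__()
        self.fc = nn.Linear(feat_dim, n_domains)

    def forward(self, z):
        return self.fc(z)                   # logits; the softmax outputs p_n are taken in the losses
```

The label predictor operates on features extracted from X_p, while the domain predictor operates on features from the unlearning data and produces the softmax outputs p_n used in the domain loss and in the confusion loss introduced next.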
The softmax outputs are also used in the confusion loss, which removes domain information by penalising deviations from a uniform distribution:

L_{conf}(X_u; \Theta_{repr}, \Theta_d) = -\sum_{n=1}^{N} \frac{1}{N} \log(p_n)    (3)

Therefore, the overall method minimises the total loss function:

L(X_p, X_u, y_p, d_u; \Theta_{repr}, \Theta_p, \Theta_d) = L_p(X_p, y_p; \Theta_{repr}, \Theta_p) + \alpha L_d(X_u, d_u; \Theta_{repr}, \Theta_d) + \beta L_{conf}(X_u; \Theta_{repr}, \Theta_d)    (4)

where α and β represent the weights of the relative contributions of the different loss functions. Equations (2) and (3) cannot be optimised in a single step because they act in direct opposition to each other. Therefore, we update the loss functions iteratively, and the method requires three forward and backward passes to be evaluated per batch. Each loss function is used to update a subset of the parameters, with the rest being frozen: in the first stage, the parameters in the feature extractor and label predictor are updated using the main task loss; in the second, the domain predictor is updated using the domain loss; and in the third, the feature extractor is updated using the confusion loss, so that the network converges on a feature representation that is invariant to the acquisition scanner.

As the loss functions show, different sets of data can be used to evaluate the different terms. This not only means that we can unlearn scanner information for data points for which we do not have main task labels, but also that we can unlearn scanner information using curated (or 'matched') subsets of data points when potentially problematic biases exist in the data as a whole. These scenarios are explored in the experiments to follow.

The aim is to remove the scanner information from the feature representation without removing information that is discriminative for the age prediction task. In this scenario, all three loss functions can be evaluated on a single combined dataset X, where we have labels y for the full set and know the acquisition scanner d for all data points, meaning that the overall method minimises the total loss function:

L(X, y, d; \Theta_{repr}, \Theta_p, \Theta_d) = L_p(X, y; \Theta_{repr}, \Theta_p) + \alpha L_d(X, d; \Theta_{repr}, \Theta_d) + \beta L_{conf}(X; \Theta_{repr}, \Theta_d)

Figure 3: The three biased datasets used in this experiment, with 5 years, 10 years and 15 years of overlap. Only the shaded overlap region is used for unlearning scanner information; all the data points are used to evaluate the main loss function.

When the distribution of the data for the main task is similar across all data points, unlearning can simply be completed on all of the data points. However, where there exists a large difference between the two domains, such that the main task label is highly indicative of the scanner, it is likely that the unlearning process will also remove information that is important for the main task. This latter scenario could, for instance, arise where the age distributions for the two studies are only slightly overlapping or where nearly all subjects with a given condition are collected on one of the scanners. To reduce this problem, we utilise the flexibility of the training framework and, whilst evaluating the main task on the whole dataset, perform the scanner unlearning (Eqs. (2) and (3)) on only a subset of the data. For the case of age prediction, we perform unlearning only on the overlapping section of the data distributions. If we were to consider the case where we had data from both patients and healthy controls and, for instance, most of the patients had been scanned on one of the two scanners, we could perform unlearning on only the healthy controls rather than the whole dataset. As we do not need main task labels for the data used for unlearning, unlearning could also be performed on a separate dataset, so long as the scanner and protocol remained identical.
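To make the iterative scheme concrete, the following is a minimal sketch of the three forward and backward passes evaluated per batch, with each loss updating only its own subset of parameters. The stand-in modules, optimiser settings and loss weights are illustrative assumptions; the main task is shown as an L1 age regression, and the confusion loss is written as the cross-entropy against a uniform target, following Eqs. (1) to (3).

```python
# Minimal sketch of the three-stage update per batch (Eqs. (1)-(3)); the module
# definitions, optimiser settings and loss weights are illustrative stand-ins.
import torch
import torch.nn as nn
import torch.nn.functional as F

n_domains, feat_dim = 3, 64
encoder    = nn.Sequential(nn.Flatten(), nn.Linear(128, feat_dim), nn.ReLU())  # Theta_repr (stand-in)
regressor  = nn.Linear(feat_dim, 1)          # Theta_p: main-task head (age regression)
domain_clf = nn.Linear(feat_dim, n_domains)  # Theta_d: scanner classifier

opt_main   = torch.optim.Adam(list(encoder.parameters()) + list(regressor.parameters()), lr=1e-4)
opt_domain = torch.optim.Adam(domain_clf.parameters(), lr=1e-4)
opt_conf   = torch.optim.Adam(encoder.parameters(), lr=1e-4)
alpha, beta = 1.0, 10.0                      # illustrative weights for Eqs. (2) and (3)

def train_batch(x_p, y_p, dom_p, x_u, d_u):
    """x_p, y_p, dom_p: main-task batch; x_u, d_u: batch used for scanner unlearning."""
    # Stage 1: update Theta_repr and Theta_p with the main-task loss,
    # averaged per domain so that the largest dataset does not dominate (Eq. (1)).
    opt_main.zero_grad()
    y_hat = regressor(encoder(x_p)).squeeze(1)
    loss_p = torch.stack([F.l1_loss(y_hat[dom_p == n], y_p[dom_p == n])
                          for n in torch.unique(dom_p)]).mean()
    loss_p.backward()
    opt_main.step()

    # Stage 2: update Theta_d only, learning to predict scanner from the
    # frozen feature representation (Eq. (2), categorical cross-entropy).
    opt_domain.zero_grad()
    with torch.no_grad():
        z_u = encoder(x_u)
    loss_d = alpha * F.cross_entropy(domain_clf(z_u), d_u)
    loss_d.backward()
    opt_domain.step()

    # Stage 3: update Theta_repr only, penalising deviation of the domain
    # classifier's softmax outputs from a uniform distribution (Eq. (3)).
    opt_conf.zero_grad()
    p = F.softmax(domain_clf(encoder(x_u)), dim=1)
    loss_conf = beta * (-(1.0 / n_domains) * torch.log(p + 1e-8).sum(dim=1).mean())
    loss_conf.backward()
    opt_conf.step()
    return loss_p.item(), loss_d.item(), loss_conf.item()

# Example call with synthetic tensors (flattened-image stand-ins).
x_p, y_p  = torch.randn(8, 128), torch.rand(8) * 60 + 20
dom_p     = torch.randint(0, n_domains, (8,))
x_u, d_u  = torch.randn(8, 128), torch.randint(0, n_domains, (8,))
train_batch(x_p, y_p, dom_p, x_u, d_u)
```

Because the data for the unlearning stages is passed separately from the main-task batch, the same loop covers the case where the unlearning subset is a curated or matched subset of the full dataset.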
We sample the Biobank and OASIS datasets so as to create three degrees of overlap between the datasets: 5 years, 10 years and 15 years (Fig. 3), and test on the same test sets as before, with subjects across the whole age range included.

Having demonstrated the network on three scanners, we now demonstrate it on a multi-site dataset, ABIDE [11]. We split the data into 90% training and 10% testing across each scanner and process the data as described in Section 2.

Figure 4: The general network architecture can be adapted to allow us to also remove other confounds from the data. In addition to the datasets used for the main task and the scanner unlearning, we add a data pair where X_c are the input images used for deconfounding and y_c are the confound labels.

In addition to unlearning scanner information to harmonise the data, we can also adapt the framework to remove other confounds. Therefore, the overall method will minimise the loss:

L = L_p + \alpha L_d + \beta L_{conf} + \sum_{j=1}^{J} \left( \gamma_j L_{c_j} + \phi_j L_{conf_j} \right)

where we consider J different confounds we wish to remove, γ_j is the weighting for the classification loss L_{c_j} for the j-th confound, and similarly φ_j is the weighting of the confound confusion loss L_{conf_j} for the j-th confound. We demonstrate this with sex as the confound, and so the classification loss can simply be the binary cross-entropy loss.

If we instead swap the tasks such that sex prediction is now the main task, we can consider the process of removing the age information. The continuous age labels could be approximately converted into categorical labels by binning the data into single-year bins spanning the age range. This, however, would not encode the fact that a prediction of 65 for a true label of 66 encodes more true age information than a prediction of 20. We therefore convert the true age labels into a softmax label around the true age, normally distributed as N(µ, σ²), where µ is the true age label and σ was set to 10 empirically, allowing us to maintain relative information between bins. The value of 10 was chosen because, when used in normal training, it minimised the mean absolute error for age prediction.

Figure 5: Network architecture used for unlearning with segmentation. X_p represents the input images used to evaluate the primary task, with y being the main task label segmentations. X_u are the input images used for unlearning scanner information, with domain labels d_u. The domain discriminator for unlearning can be attached from A, B, or the two in combination. If it is attached from A and B together, the first fully connected layers are concatenated to produce a single feature representation.

For the segmentation task we used subjects from the Biobank dataset (2095 training, 937 testing) and healthy subjects from the OASIS dataset [29] (813 training, 217 testing). As before, these were resized to 128 × 128 × 128 voxels and the intensity values normalised, but they were then split into 2D slices so that we trained a 2D network. Multi-class Dice loss was used as the primary task loss function. All experiments were completed on a V100 GPU and were implemented in Python (3.6) and PyTorch (1.0.1).

As can be seen in Fig. 5, the general structure is identical to that used for age prediction, apart from the location of the domain predictor.
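As a concrete illustration of the continuous-confound scheme described above, the following minimal sketch converts a true age into a Gaussian-shaped soft label over single-year bins with σ = 10 and scores a prediction with the KL divergence. The age range, bin choice and variable names are illustrative assumptions.

```python
# Minimal sketch: Gaussian-shaped soft age labels over 1-year bins (sigma = 10)
# scored with the KL divergence; range and names are illustrative assumptions.
import torch
import torch.nn.functional as F

def age_to_soft_label(age, bin_centres, sigma=10.0):
    """Soft label proportional to N(age, sigma^2) evaluated at each 1-year bin."""
    logits = -0.5 * ((bin_centres - age) / sigma) ** 2
    return F.softmax(logits, dim=0)          # normalised so the bins sum to 1

bin_centres = torch.arange(45.0, 81.0)       # illustrative 1-year bins spanning the age range
target = age_to_soft_label(66.0, bin_centres)

# KL divergence between the age head's predicted distribution and the soft target;
# pred_logits is a random stand-in for the network's output over the bins.
pred_logits = torch.randn(1, len(bin_centres))
loss = F.kl_div(F.log_softmax(pred_logits, dim=1), target.unsqueeze(0), reduction='batchmean')
```

With this encoding, a prediction of 65 for a true label of 66 is penalised far less than a prediction of 20, which is the relative-information property the single-year categorical labels would lose.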
Although the unlabelled data points cannot be used to evaluate the main task, they can be used for scanner unlearning. No changes to the architecture are required; rather, we simply evaluate the main task for those data points for which we have main task labels and use all data points for unlearning, such that the overall method minimises:

L = L_p(X_p, y_p; \Theta_{repr}, \Theta_p) + \alpha L_d(X, d; \Theta_{repr}, \Theta_d) + \beta L_{conf}(X; \Theta_{repr}, \Theta_d)

where X_p is the subset of X for which we have main task labels y_p available and the full dataset X is used in unlearning scanner information. We explore the effect of increasing the number of labelled data points on the final segmentation performance.

Figure 6: a) t-SNE plot of the fully connected layer in Θ_repr before unlearning. It can be seen that the domains can be almost entirely separated, except for two data points grouped incorrectly, showing that data from each scanner has its own distinct distribution. b) t-SNE plot of the fully connected layer in Θ_repr after unlearning. It can be seen that, through the unlearning, the distributions become entirely jointly embedded.

In Fig. 6, a t-SNE plot [28] can be seen, allowing visualisation of the frozen feature representation Θ_repr. It can be seen that, before unlearning, the features produced are entirely separable by scanner, but after unlearning the scanner features become jointly embedded and so the feature embedding is not informative of scanner. The t-SNE plot demonstrates that, without unlearning, not only are there identifiable scanner effects but they also affect the age predictions. On the other hand, it can be seen that we are able to remove scanner information using our unlearning technique, such that the data points for all three scanners share the same embedding. This is confirmed by the scanner classification accuracy being almost random chance after unlearning has been completed. It can also be seen that unlearning does not substantially decrease the performance on the main task, and that the unlearning process creates a feature space that generalises well.

The results for training with balanced datasets (equal numbers of subjects from each scanner) can be seen in Table 2. For both normal training and unlearning, there was a large decrease in performance using the balanced datasets, but the decrease was less pronounced when using unlearning than when using normal training. Using balanced datasets made no difference to the ability to classify scanner: with normal training the classifier was still able to distinguish fully between scanners; with unlearning, the classification was almost random chance. Therefore, any advantage gained from having balanced datasets is outweighed by the reduction in the size of the training datasets.

It can be seen from Fig. 7 that the value of β has no effect on the achieved MAEs across the values tested. Therefore, the performance on the main task is robust to the choice of β, and the value chosen can be selected to maximise the stability of training without impacting the final prediction values. Stability could also be controlled by using different learning rates for each stage; however, this was found to be harder to tune in practice.

Finally, we explored the choice of the batch size used. For training, the largest batch that could fit into memory was used; here the effect of smaller batches is explored, with the results shown in Fig. 8.

We assessed the performance of the network when training with biased datasets curated from subsets of the Biobank and OASIS datasets. We considered the cases of 5-year, 10-year and 15-year overlaps, where the smaller the overlap, the harder the task. The trained networks were tested on the same test sets used above, with subjects across the whole age range included.
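A minimal sketch of how the unlearning subset could be drawn in this biased-dataset setting follows: the main-task loss sees every subject, while the scanner losses of Eqs. (2) and (3) see only those subjects whose ages fall in the overlapping range. The ages, scanner labels and overlap bounds below are synthetic, illustrative stand-ins.

```python
# Minimal sketch: select the unlearning subset from the overlapping age range,
# while the main-task loss uses all subjects. All values are synthetic stand-ins.
import numpy as np

rng = np.random.default_rng(0)
ages     = np.concatenate([rng.uniform(45, 70, 500),    # hypothetical scanner 0 cohort
                           rng.uniform(60, 85, 500)])   # hypothetical scanner 1 cohort
scanners = np.repeat([0, 1], 500)

overlap_lo, overlap_hi = 60.0, 70.0                      # e.g. a 10-year overlap

main_idx        = np.arange(len(ages))                                        # used for L_p
unlearn_idx     = np.where((ages >= overlap_lo) & (ages <= overlap_hi))[0]    # used for Eqs. (2), (3)
unlearn_domains = scanners[unlearn_idx]                                        # d_u for those subjects
```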
The results can be seen in Table 3, where we compare normal training, naïve unlearning on the whole dataset, and unlearning on only the overlapping age range. It can be seen that, for all three degrees of overlap, standard training leads to much larger errors than unlearning, and that unlearning on only the overlap gives a lower error than unlearning on all data points. Plots of the MAEs with age can be seen in Fig. 10. Unlearning on all data points performed worse on the testing data, probably indicating that the features removed also encode some age information and so generalise less well across the whole age range.

For the ABIDE data, the scanner classification accuracy after unlearning was reduced to 6.42%, where random chance was 6.25%. Therefore, we can see that scanner information is present in the feature embeddings, even when harmonised protocols are used. These results therefore show that the framework can be applied to many sites with no changes needed. They also show that the framework is applicable to lower numbers of subjects, with some of the sites having as few as 27 subjects for training. We thus foresee that the framework should be applicable to many studies.

Figure 11: MAE results broken down by site for the ABIDE data, comparing normal training to unlearning.

The effect of removing sex information as an additional confound, in addition to harmonising for scanner, was investigated. In Fig. 12 it can be seen that there is no significant effect on the MAE results when removing sex information in addition to scanner information. Unlearning sex information had no substantive effect on the ability to remove scanner information, with the scanner classification accuracy being 48% and 49%, respectively. The sex classification accuracy was 96% before unlearning and 54% after unlearning. Therefore, we can remove multiple pieces of information simultaneously and, so long as the information we wish to unlearn does not correlate with the main task, we can do so without significantly reducing the performance on the main task.

We then considered the case where sex is highly correlated with scanner, with the data curated such that 80% of the subjects in the OASIS dataset are male and 80% of the subjects in the Biobank dataset are female. Table 4 compares normal training on these datasets to unlearning sex on all data points and to unlearning sex on a subset with balanced numbers of each sex for each scanner. The full testing set was still used. It can be seen that there is little difference between unlearning on all data points, or on just a balanced subset, compared to the normal training baseline. It can also be seen that neither method affects the network's ability to unlearn scanner or sex information, and so there is no need to change the unlearning training from the standard case.

Finally, we considered the case where sex (the confound we wish to unlearn) is highly correlated with age (the main task), such that 80% of the subjects below the median age are female and 80% of the subjects over the median age are male. Again, normal training was compared to unlearning on the whole dataset and to unlearning on a curated subset with uniform distributions of sex with age for each scanner. Testing was again performed on the full testing datasets. It was found that it was not possible to unlearn on the whole distribution: the training became unstable after approximately 5 epochs, with the training and validation losses exploding before we were able to unlearn sex and scanner information.
The results from unlearning scanner can be seen in Table 5, which shows that, by unlearning only on a subset of data points with equal sample numbers, we can almost entirely unlearn both scanner and sex information. Figure 13 shows a t-SNE plot [28] of the feature embeddings, where it can be seen that, before unlearning, the data was largely separable by scanner and sex and that, after unlearning, these are indistinguishable.

Figure 13: a) t-SNE plot of the fully connected layer in Θ_repr from training without unlearning, having trained on a dataset where sex correlates highly with age. It can be seen that the scanners can still be almost entirely separated, except for two data points grouped incorrectly (located within the cluster of light blue data points), showing that data from each scanner has its own distinct distribution, and that the data points can also be entirely split by sex for the Biobank data. b) t-SNE plot of the fully connected layer in Θ_repr after unlearning. It can be seen that, through the unlearning, the distributions become entirely jointly embedded in terms of both scanner and sex.

For the removal of age as a confound, we used sex prediction as the main task. We achieved an average accuracy of 96.3% on the sex prediction task before unlearning and 95.9% after unlearning. As before, we were able to unlearn scanner information, reducing the scanner classification accuracy from 100% to 53%. Figure 14a shows the averaged softmax labels from the age prediction task with normal training, where it can be seen that there is a large degree of agreement between the true labels and the predicted labels, showing that we are able to learn age using the continuous labels and the KL divergence as the loss function. We achieved MAE values of 3.26 ± 2.47 years for Biobank and 4.09 ± 3.46 years for OASIS with normal training in this manner.

Figure 14b shows the softmax labels after unlearning. It can be seen that the predicted labels no longer follow the same distribution as the true labels and that they are distributed around the random chance value indicated by the dotted line. It can also be seen that there is no trend towards the true age value, indicating that a large amount of the age information has been removed. After unlearning, the MAEs for the age prediction task were 17.

Table 6: Comparison of the location of the domain predictor (see Fig. 5): A) at the final convolutional layer, B) at the bottleneck, and A+B) the combination of the two locations. The scanner classification accuracy was the accuracy achieved by a separate domain predictor using the fixed feature representation at the final convolutional layer. Random chance is given in brackets.

The results from comparing the location of the domain predictor can be seen in Table 6, where we compare the Dice scores and consider the scanner classification accuracy. The results for training on the combination of datasets with normal training and unlearning can be seen in Table 7, averaged across tissue type, and in Fig. 15. Only the data points with labels were used for evaluating the main task, but all of the data points were still used for unlearning scanner, as these loss functions do not require main task labels to be evaluated. The results can be seen in Fig. 16.

The results also show that we can use the unlearning scheme even when there is a strong relationship between the main task label and the acquisition scanner.
Figure 16: Dice scores for the three different tissue types for the OASIS data with increasing numbers of OASIS training slices, comparing both normal training and unlearning, with the full Biobank dataset used throughout. For clarity, the x axis is not plotted to scale.

This could be especially useful when combining data between studies with different designs and with very different numbers of subjects from each group. As we can perform unlearning on a different dataset to the main task, we have the flexibility to apply the method to a range of scenarios. At the limit where there exists no overlap between the datasets' distributions, unlearning scanner information would be highly likely to remove information relating to the main task. This could then only be solved by having an additional dataset for each scanner, acquired with the same protocol, and so represents a potential limitation of the method. The extension to the ABIDE dataset, however, shows the applicability of the method to many small datasets, and therefore this may be solvable using a set of small datasets which together span the whole distribution.

We have also shown that the approach can be extended to remove additional confounds, and have demonstrated how it can be extended to remove continuous confounds such as age. For each confound to be removed, we require two additional loss functions and two additional forward and backward passes. Therefore, the training time will increase with each additional confound, presenting a potential limitation.

The computational aspects of this research were supported by the Wellcome Trust Core Award.

Example training graphs from the age prediction task with three datasets.

For comparison, we also trained a domain adaptation network. As confirmed across the literature, it was unstable to train, limiting its performance. The results from training can be seen in Table 8, where it can be seen that the scanner classification accuracy was much lower than random chance, indicating that scanner information remained in the feature representation. It can also be seen that the performance on the main task was reduced, with the domain adaptation leading to lower performance compared to normal training.

References
Unsupervised domain adaptation by backpropagation. International Conference on Machine Learning.
Domain-adversarial training of neural networks.
Neuroharmony: a new tool for harmonizing volumetric MRI data from unseen scanners.
Machine learning with multi-site imaging data: an empirical study on the impact of scanner effects.
Fischl, B.: Reliability of MRI-derived measurements of human cerebral cortical thickness: the effects of field strength, scanner upgrade and manufacturer.
Multi-site harmonization of diffusion MRI data via method of moments.
The Alzheimer's Disease Neuroimaging Initiative (ADNI): MRI methods.
Adjusting batch effects in microarray expression data using empirical Bayes methods.
Reliability in multi-site structural MRI studies: effects of gradient non-linearity correction on phantom and human data.
Overcoming the disentanglement vs reconstruction trade-off via Jacobian supervision.
Visualizing data using t-SNE.
Open Access Series of Imaging Studies (OASIS): cross-sectional MRI data in young, middle aged, nondemented, and demented older adults.
Inter-site and inter-scanner diffusion MRI data harmonization.
Harmonizing diffusion MRI data across multiple sites and scanners.
Scanner invariant representations for diffusion MRI harmonization.
PyTorch: an imperative style, high-performance deep learning library.
Harmonization of large MRI datasets for the analysis of brain imaging patterns throughout the lifespan.
U-Net: convolutional networks for biomedical image segmentation.
Beyond sharing weights for deep domain adaptation.
Very deep convolutional networks for large-scale image recognition.
Harmonization of diffusion MRI datasets with adaptive dictionary learning. bioRxiv.
UK Biobank: an open access resource for identifying the causes of a wide range of complex diseases of middle and old age.
Deep CORAL: correlation alignment for deep domain adaptation. European Conference on Computer Vision.
Effect of scanner in longitudinal studies of brain volume changes.
Effects of study design in multi-scanner voxel-based morphometry studies.
A survey on deep transfer learning.
Unbiased look at dataset bias.
Simultaneous deep transfer across domains and tasks.
Domain adaptation for Alzheimer's disease diagnostics.
Detect and correct bias in multi-site neuroimaging datasets.
Intensity warping for multisite MRI harmonization.
Statistical harmonization corrects site effects in functional connectivity measurements from multisite fMRI data.
Segmentation of brain MR images through a hidden Markov random field model and the expectation maximization algorithm.
Harmonization of infant cortical thickness using surface-to-surface cycle-consistent adversarial networks. Conference proceedings.
Unpaired image-to-image translation using cycle-consistent adversarial networks.