title: Brain Cancer Survival Prediction on Treatment-naïve MRI using Deep Anchor Attention Learning with Vision Transformer
authors: Xu, Xuan; Prasanna, Prateek
date: 2022-02-03

Image-based brain cancer prediction models, based on radiomics, quantify the radiologic phenotype from magnetic resonance imaging (MRI). However, these features are difficult to reproduce because of variability in acquisition and preprocessing pipelines. Despite evidence of intra-tumor phenotypic heterogeneity, the spatial diversity between different slices within an MRI scan has been relatively unexplored by such methods. In this work, we propose a deep anchor attention aggregation strategy with a Vision Transformer to predict survival risk for brain cancer patients. A Deep Anchor Attention Learning (DAAL) algorithm is proposed to assign different weights to slice-level representations using trainable distance measurements. We evaluated our method on N = 326 MRIs. Our results outperformed attention multiple instance learning-based techniques. DAAL highlights the importance of critical slices and corroborates the clinical intuition that inter-slice spatial diversity can reflect disease severity and is implicated in outcome.

Clinical Motivation. Glioblastoma (GBM), an aggressive form of brain cancer, has a poor prognosis, with a median survival of around 12-15 months [1]. This has led to an unmet clinical need for personalized risk prediction approaches at an early stage, so that treatment regimens can be tailored to individual risk profiles rather than resorting to a one-size-fits-all approach. MRI is a key tool in brain cancer diagnosis and treatment planning. Previous works have leveraged multi-parametric MRI to detect GBM saliency [2]. Most existing imaging models have relied on radiomic analysis [3, 4], which is sensitive to image pre-processing and often hard to reproduce. Though some of this intra-tumor heterogeneity is manifested on imaging, the spatial diversity and the interactions between different sub-regions of a tumor have not been fully explored from a phenotype perspective. Survival analysis studies have largely been limited to characterizing intra-tumoral and peri-tumoral heterogeneity through slice-wise statistics [5], without taking into account the intra-tumor heterogeneity reflected as inter-slice and intra-slice spatial diversity on MRI (illustrated in Figure 1).

Technical Motivation. Several works have leveraged pretrained models to extract imaging features from radiological scans [6]. Pretrained models such as VGG [7] and ResNet [8] are usually trained on ImageNet [9]. Pretrained AlexNet models have been used in a transfer learning setting to detect abnormalities in brain MRI [10, 11]. The visual composition of medical images, especially radiology scans, is fundamentally different from that of natural images. Pretrained representations may not effectively capture the subtle local variations prevalent in MRI scans, more so in brain tumors [12]. Another challenge stems from the multiple views of 3D MRI scans and the multiple cross-sections within a given view. Class labels are generally assigned at the MRI level, not at the slice level, and because of intra-tumor heterogeneity in cancers it is unreasonable to assign the same patient-level label to every slice. Previous works on medical imaging leveraging convolutional neural networks focus on local variations in image patterns [13].
The Vision Transformer (ViT) [14] models long-range dependencies and shows that a pure transformer can perform well on image classification tasks. It closely follows the original Transformer [15] and has been widely adopted for image classification [16]. Compared with computer vision tasks on natural images, MRI datasets are usually small, typically comprising several hundred scans, whereas ViTs require large-scale datasets for training. To address these challenges, we first utilize slices from the axial, coronal, and sagittal planes to extract imaging representations for each patient. Unlike existing prognostic methods that consider only axial planes [3], slices from the three planes enable partial retention of 3D information and a more comprehensive feature extraction with transfer learning. Secondly, we use a vision transformer [14] to evaluate the spatial relationships between different small regions of the tumor and its periphery within one slice. Most importantly, we propose a deep anchor attention learning algorithm to exploit the inter-slice spatial diversity. To mimic the radiological assessment workflow, slices with the largest tumor area are designated as anchor slices. This strategy provides a way to aggregate the slice-level representations into a patient-level representation using a latent-space comparison between anchor slices and the other slices. Our main contributions are as follows:

• Intra-slice spatial diversity in MRI is captured by a self-attention mechanism. We leverage a vision transformer to quantify the spatial diversity between small regions within a given slice, as shown in Figure 1(a).

• A Deep Anchor Attention Learning (DAAL) algorithm is proposed to provide an efficient aggregation strategy for slice-level representations. We adapt DSMIL [17] to highlight the critical slices in a 3D MRI.

Our experimental results corroborate the clinical intuition that intra-tumor heterogeneity, reflected on imaging as intra- and inter-slice spatial diversity (as described in Figure 1), can reflect disease severity and is implicated in prognosis.

Consider a set of N patients X_i, i = 1, 2, ..., N, with labels (t_i, δ_i), where δ_i = 1 indicates a death event and δ_i = 0 corresponds to a censored event; t_i is the time to event or censoring. The overarching goal is to predict the survival risk r_i for each patient. The workflow is illustrated in Figure 2(a). First, we utilize slices from the three different planes, i.e., the axial, coronal, and sagittal planes, to extract imaging features for each patient, as shown in Figure 2(b). We analyze the 3D volumes as 2D slices in multiple planes. For a given MRI, we count the tumor pixels per slice and determine the slice indices anchor_x, anchor_y, anchor_z corresponding to the largest tumor area in the sagittal, coronal, and axial planes, respectively. The corresponding slices are designated as anchor slices s_anchor_x, s_anchor_y, and s_anchor_z. Based on these anchor slices, we select neighbor slices in the ranges [anchor_x − k_x1, anchor_x + k_x2], [anchor_y − k_y1, anchor_y + k_y2], and [anchor_z − k_z1, anchor_z + k_z2], where k_q1 and k_q2 denote the number of left and right neighbors around s_anchor_q, q ∈ {x, y, z}. For each selected slice, we crop a bounding box around the tumor region and then resize it to 224 × 224. We thus obtain a 2D slice list containing the cropped slices from the three planes for each MRI volume.
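To make the slice-selection step concrete, here is a minimal sketch, assuming the MRI volume and its binary tumor mask are NumPy arrays of the same shape and using OpenCV only for resizing; the function names and the default neighbor counts are illustrative, not from the authors' code.

```python
import numpy as np
import cv2  # used only to resize the cropped 2D slices


def anchor_index(mask: np.ndarray, axis: int) -> int:
    """Index of the slice with the largest tumor area along the given axis."""
    # Sum tumor pixels over the two in-plane axes, leaving one value per slice.
    per_slice = mask.sum(axis=tuple(a for a in range(3) if a != axis))
    return int(per_slice.argmax())


def crop_and_resize(slice2d: np.ndarray, mask2d: np.ndarray, size: int = 224) -> np.ndarray:
    """Crop a bounding box around the tumor in a 2D slice and resize it."""
    ys, xs = np.nonzero(mask2d)
    if len(ys) == 0:  # no tumor pixels in this slice: resize the whole slice
        return cv2.resize(slice2d.astype(np.float32), (size, size))
    crop = slice2d[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    return cv2.resize(crop.astype(np.float32), (size, size))


def build_slice_list(volume: np.ndarray, mask: np.ndarray, k_left: int = 5, k_right: int = 5):
    """Return cropped 2D slices around the anchor slice of each plane (illustrative k values)."""
    slices = []
    for axis in range(3):  # 0: sagittal, 1: coronal, 2: axial (assumed orientation)
        anchor = anchor_index(mask, axis)
        lo = max(anchor - k_left, 0)
        hi = min(anchor + k_right, volume.shape[axis] - 1)
        for idx in range(lo, hi + 1):
            img = np.take(volume, idx, axis=axis)
            msk = np.take(mask, idx, axis=axis)
            slices.append(crop_and_resize(img, msk))
    return slices
```

In practice the left/right neighbor counts would be set per plane (k_x1, k_x2, and so on) and chosen so that the resulting K slices cover the desired fraction of the tumor, as discussed next.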
The number of slices K, with K = k_x1 + k_x2 + k_y1 + k_y2 + k_z1 + k_z2 + 3, influences the fraction of the tumor that is selected. The more neighbor slices are taken around the anchor slices, the larger the volume of the selected tumor region, until the whole tumor is accounted for. The ratio (# tumor pixels in the selected slices) / (# all tumor pixels) is shown in Figure 2(c). As K increases, the number of tumor pixels included across these slices also increases. We evaluate whether the choice of K influences the C-index.

We fine-tune a model on dataset D_1 [18] and then use the trained model to extract features on dataset D_2; D_1 and D_2 are described in Section 3.1. Images in D_1 are fed into the ImageNet-pretrained ViT for classification. For each image, the tumor region is cropped and resized to 224 × 224 as the input to the pretrained model [19]. The patch size is 16 × 16. The input channel is changed to 1 and the number of classes is changed to 3 to match the D_1 dataset. The model is trained for 500 epochs with the Adam optimizer. We achieve a validation accuracy of 0.94 in classifying glioma, meningioma, and pituitary tumor. We thus obtain a ViT fine-tuned on MRI slices, which is then leveraged as a feature extractor for survival prediction. Each slice from the slice list is fed into the extractor to obtain a slice-level representation, yielding a list of slice representations for each patient.

This section discusses the aggregation of slice-level representations into a patient-level representation. Suppose a patient has K slices with the corresponding slice-level representation list H = {h_1, h_2, ..., h_K}; the target is to generate the patient-level representation. As an aggregation strategy, we propose a Deep Anchor Attention Learning (DAAL) method that assigns a weight to each slice-level representation. Each slice-level representation h_s is transformed into a query q_s ∈ R^{D×1} and an information vector v_s ∈ R^{L×1}:

q_s = W_q h_s,    v_s = W_v h_s,    (1)

where W_q and W_v are the weight matrices. The queries are matched between the slices and the anchor slices, and information from each slice-level representation is extracted through the information vector v_s. The anchor slice representations are h_anchor_x, h_anchor_y, h_anchor_z, with corresponding queries q_anchor_p (p = x, y, z). The distance measurement U between an arbitrary representation and an anchor slice representation is defined as

U_p(h_s, h_anchor_p) = exp(⟨q_s, q_anchor_p⟩) / Σ_{k=1..K} exp(⟨q_k, q_anchor_p⟩),    (2)

where ⟨·, ·⟩ denotes the inner product of two vectors. It measures the similarity between each query (including the anchor queries) and the anchor query. We obtain three distances U_p, p ∈ {x, y, z}, corresponding to the three anchor slices. The patient-level representation is the weighted element-wise sum of the information vectors v_s of all slice-level representations, with U_p(h_s, h_anchor_p) as the weights:

b_p = Σ_{s=1..K} U_p(h_s, h_anchor_p) v_s.    (3)

We thus obtain three patient-level representations b_x, b_y, and b_z, which are fed into fully connected layers to output the risks r_ix, r_iy, and r_iz. Two experiments are performed to evaluate our method. First, we use only r_ix as the patient risk score (DAAL-single). Second, we use max(r_ix, r_iy, r_iz) as the patient risk score (DAAL-multiple). We use the negative log partial likelihood as the loss function (Equation 4) [13]:

L = − Σ_{i: δ_i = 1} [ r_i − log Σ_{j: t_j ≥ t_i} exp(r_j) ].    (4)

A schematic implementation of this aggregation and loss is sketched below, after the description of the fine-tuning dataset.

Dataset 1 (D_1) is a brain tumor dataset that includes 3064 T1-weighted contrast-enhanced images with three kinds of brain tumors, namely glioma, meningioma, and pituitary tumor [18]. It provides 2D slices in three planes along with the associated tumor mask and tumor class. We leverage D_1 for ViT fine-tuning.
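Returning to the DAAL aggregation and loss above, the following is a minimal PyTorch sketch under stated assumptions: 768-dimensional ViT slice features, one aggregator per anatomical plane (whether the projection weights are shared across planes is not specified in the text), and illustrative layer names (proj_q, proj_v, risk_head) that are not taken from the authors' code. The softmax-normalized inner product follows the DSMIL formulation the paper adapts.

```python
import torch
import torch.nn as nn


class DAALAggregator(nn.Module):
    """Schematic Deep Anchor Attention aggregation for one plane (illustrative)."""

    def __init__(self, feat_dim: int = 768, query_dim: int = 128, info_dim: int = 512):
        super().__init__()
        self.proj_q = nn.Linear(feat_dim, query_dim, bias=False)  # q_s = W_q h_s
        self.proj_v = nn.Linear(feat_dim, info_dim, bias=False)   # v_s = W_v h_s
        self.risk_head = nn.Linear(info_dim, 1)                   # maps b_p to a risk score

    def forward(self, h: torch.Tensor, anchor_idx: int) -> torch.Tensor:
        # h: (K, feat_dim) slice-level representations of one patient
        q = self.proj_q(h)                    # queries, shape (K, query_dim)
        v = self.proj_v(h)                    # information vectors, shape (K, info_dim)
        sim = q @ q[anchor_idx]               # inner products <q_s, q_anchor>, shape (K,)
        u = torch.softmax(sim, dim=0)         # distance/attention weights U_p
        b = (u.unsqueeze(1) * v).sum(dim=0)   # weighted sum -> patient-level representation b_p
        return self.risk_head(b).squeeze(-1)  # scalar risk r_ip


def cox_partial_likelihood_loss(risk: torch.Tensor, time: torch.Tensor,
                                event: torch.Tensor) -> torch.Tensor:
    """Negative log partial likelihood (Equation 4 above), ignoring ties in event times."""
    order = torch.argsort(time, descending=True)     # descending time: risk sets become prefixes
    risk, event = risk[order], event[order]
    log_risk_set = torch.logcumsumexp(risk, dim=0)   # log sum_{j: t_j >= t_i} exp(r_j)
    return -((risk - log_risk_set) * event).sum()
```

One such aggregator per plane would produce r_ix, r_iy, and r_iz; DAAL-single keeps only r_ix, while DAAL-multiple takes their maximum as the patient risk.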
To validate our performance, we utilized another dataset D_2 comprising N = 326 MRIs and corresponding tumor masks from the BraTS2020 challenge [20, 21, 22]. Available sequences include pre-contrast T1, post-contrast T1-weighted (T1Gd), T2-weighted, and T2 Fluid Attenuated Inversion Recovery (T2-FLAIR) volumes. The tumor annotations comprise the enhancing tumor, the peritumoral edema, and the necrotic and non-enhancing tumor core. We validate our proposed method on the T1Gd scans. 49 patients are set aside as the test set, preserving the ratio of censored data. The remaining 277 patients are split into 5 folds for cross-validation. We therefore obtain five models in the 5-fold cross-validation setting and test each of them on the held-out test set.

The concordance index (C-index) is used to evaluate the prognostic models; a C-index of about 0.65-0.7 or higher is generally considered strong in survival analysis. The hazard ratio (HR) is the ratio of the hazard rates corresponding to the conditions described by two levels of an explanatory variable. We use the HR to compare the probability of death events in two groups (low-risk and high-risk). For each fold, we obtain a risk score for the patients in the test set. The two risk groups are separated by the median risk score from the training set. We then use majority voting across folds for the final risk group assignment (a small sketch of this protocol is given below, after the baseline descriptions).

Radiomics. Many prior image-based survival analysis methods have focused on radiomics. We extract 336 textural radiomic measures from the intratumoral regions of the T1Gd MRIs using the provided tumor segmentation [23], followed by principal component analysis (PCA) [24]. We use the Cox model with 10 principal components as input. The C-index is 0.5796 ± 0.0309.

ResNet18. To evaluate the difference between features from a ViT and other CNN models, we leverage a pretrained ResNet18 [8] fine-tuned on a brain tumor dataset, keeping the setup consistent with our ViT experiments. 512 features are extracted for each slice and then aggregated with our proposed DAAL strategy as well as the baseline aggregation methods described below.

Cox model. The Cox proportional hazards model is the most popular model in survival analysis [4]. After computing the ViT features H = {h_1, h_2, ..., h_K}, we calculate the mean H_mean and maximum H_max across the slice-level representations. H_mean and H_max serve as patient-level representations and are fed into the Cox model. We also feed the mean and maximum values of the ResNet18 representations into a Cox model for comparison.

DeepSurv. DeepSurv [25] is another commonly used survival prediction model. We use H_mean and H_max as the patient-level representations and feed them into the DeepSurv model from the PySurvival package [26].

DeepAttnMISL model. DeepAttnMISL [13] is a deep attention multiple instance learning method first proposed for whole-slide tissue images. We use the attention MIL [27] mechanism from DeepAttnMISL to assign weights to the ViT and ResNet18 slice-level representations, allowing a direct comparison between attention MIL and DAAL.

Table 1. C-index with the vision transformer feature extractor, for 8, 16, 24, 32, 40, 48, 64, 72, and 80 slices.

Table 1 shows the results of our methods (DAAL-single, DAAL-multiple) and attention MIL using the vision transformer feature extractor. We also consider aggregation using the mean and max of the slice-level representations with the Cox model. We see that, in most cases, our methods show higher average C-index values than attention MIL and the Cox models.
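As referenced above, the following is a small sketch of the evaluation protocol, i.e., a pairwise-comparison concordance index and the median-split risk grouping with majority voting across folds. It assumes per-patient risk scores, survival times, and event indicators are available as NumPy arrays; the function names are illustrative and not taken from the paper's code.

```python
import numpy as np


def concordance_index(time: np.ndarray, risk: np.ndarray, event: np.ndarray) -> float:
    """C-index: fraction of comparable patient pairs whose risks are correctly ordered.

    A pair (i, j) is comparable when patient i had an event and a shorter observed time
    than patient j; concordance means the model assigned patient i the higher risk.
    """
    concordant, comparable = 0.0, 0.0
    n = len(time)
    for i in range(n):
        for j in range(n):
            if event[i] == 1 and time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1
                elif risk[i] == risk[j]:
                    concordant += 0.5  # ties in predicted risk count as half
    return concordant / comparable


def assign_risk_groups(folds):
    """Median-split risk grouping with majority voting across cross-validation folds.

    `folds` is a list of (train_risk, test_risk) pairs, one per fold; each fold's test
    patients are thresholded at that fold's training-set median risk score.
    """
    votes = [(test_risk > np.median(train_risk)).astype(int)
             for train_risk, test_risk in folds]
    return (np.mean(votes, axis=0) > 0.5).astype(int)  # 1 = high-risk group
```

The resulting high- and low-risk groups can then be compared via the hazard ratio, as described above.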
More importantly, the proposed methods show higher C-index values even with a small number of slices as compared to attention MIL and the Cox models, implying that we can predict survival risk reliably using relatively little contextual information. Besides ViT, we also evaluate these methods with a ResNet18 feature extractor. Table 2 shows the C-indices with ResNet18 features. From slice number K = 32 onward, DAAL-single shows higher C-indices than all the other methods with ResNet18 features. We compare the deep learning methods (attention MIL, DAAL-single, and DAAL-multiple) in Figure 3. As may be observed, ViT yields higher C-indices than ResNet18 at identical slice number settings. The results suggest that intra-slice spatial diversity is inherently prognostic and can be captured by self-attention.

In this work, we proposed a deep anchor attention learning method with a vision transformer and demonstrated its efficacy in predicting overall survival in brain cancer patients using treatment-naïve MRI. The intra-slice spatial diversity is captured by the vision transformer. The DAAL method is used as an aggregation strategy; it highlights the important slices and estimates the inter-slice spatial diversity. The approach implicitly mimics the experts' diagnostic process, in which radiologists usually observe slices with larger, more well-defined tumor components first and then shift their attention to the other slices (which may be considered secondary, yet are valuable to the ultimate diagnosis) to evaluate disease severity. Experiments show that DAAL leads to an improvement in survival prediction over state-of-the-art methods. Our methods achieve higher C-indices with less contextual information. In the future, we may consider anchor attention mechanisms in 3D models and complementing T1Gd with other sequences (T2, FLAIR) when available.

This research study was conducted retrospectively using open access human subject data (BraTS 2020). Additional approval was not required, as confirmed by [20, 21, 22]. No funding was received for conducting this study. The authors have no financial or non-financial interests to disclose.
[1] Long-term survival with glioblastoma multiforme.
[2] A novel GBM saliency detection model using multi-channel MRI.
[3] A deep learning-based radiomics model for prediction of survival in glioblastoma multiforme.
[4] Multi-habitat radiomics unravels distinct phenotypic subtypes of glioblastoma with clinical and genomic significance.
[5] Radiomic features from the peritumoral brain parenchyma on treatment-naive multiparametric MR imaging predict long versus short-term survival in glioblastoma multiforme: preliminary findings.
[6] A comprehensive study on classification of COVID-19 on computed tomography with pretrained convolutional neural networks.
[7] Very deep convolutional networks for large-scale image recognition.
[8] Deep residual learning for image recognition.
[9] ImageNet: A large-scale hierarchical image database.
[10] Convolutional neural networks for multi-class brain disease detection using MRI images.
[11] ImageNet classification with deep convolutional neural networks.
[12] Transfusion: Understanding transfer learning for medical imaging.
[13] Whole slide images based cancer survival prediction using attention guided deep multiple instance learning networks.
[14] An image is worth 16x16 words: Transformers for image recognition at scale.
[15] Attention is all you need.
[16] A survey on visual transformer.
[17] Dual-stream multiple instance learning network for whole slide image classification with self-supervised contrastive learning.
[19] PyTorch image models.
[20] The multimodal brain tumor image segmentation benchmark (BraTS).
[21] Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features.
[22] Identifying the best machine learning algorithms for brain tumor segmentation, progression assessment, and overall survival prediction in the BRATS challenge.
[23] Computational radiomics system to decode the radiographic phenotype.
[24] Principal component analysis.
[25] DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network.
[26] PySurvival: Open source package for survival analysis modeling.
[27] Attention-based deep multiple instance learning.