key: cord-0518187-jmkato54 authors: Bergner, Benjamin; Rohrer, Csaba; Taleb, Aiham; Duchrau, Martha; Leon, Guilherme De; Rodrigues, Jonas Almeida; Schwendicke, Falk; Krois, Joachim; Lippert, Christoph title: Interpretable and Interactive Deep Multiple Instance Learning for Dental Caries Classification in Bitewing X-rays date: 2021-12-17 sha: 173d15cb47a13795e9b9a5953ecb7e0f5f4bcc23 doc_id: 518187 cord_uid: jmkato54

We propose a simple and efficient image classification architecture based on deep multiple instance learning and apply it to the challenging task of caries detection in dental radiographs. Technically, our approach contributes in two ways: first, it outputs a heatmap of local patch classification probabilities despite being trained with weak image-level labels; second, it is amenable to learning from segmentation labels to guide training. In contrast to existing methods, the human user can faithfully interpret predictions and interact with the model to decide which regions to attend to. Experiments are conducted on a large clinical dataset of ∼38k bitewings (∼316k teeth), where we achieve competitive performance compared to various baselines. When guided by an external caries segmentation model, a significant improvement in classification and localization performance is observed.

Dental caries is the most prevalent health condition, affecting more than three billion people worldwide (Kassebaum et al., 2017). For diagnosis, clinicians commonly analyze bitewing radiographs (BWRs), which show the maxillary and mandibular teeth of one side of the jaw. However, the assessment of caries in bitewings is associated with low detection rates. For example, Schwendicke et al. (2015) reported a domain-expert sensitivity of only 24% (95% CI: 21%-26%) for the detection of both initial and advanced carious lesions.
The challenging nature of caries detection and the growing quantity of dental data motivate the use of deep learning techniques for this task. To support dentists, such models must overcome several technical challenges: (1) Diagnosing caries is a low signal-to-noise ratio problem. That is, lesions may occupy only a few pixels in the image. Standard convolutional neural networks (CNNs) have been shown to struggle in this setting (Pawlowski et al., 2020). In contrast, models using attention are designed to focus on important regions while ignoring the prevalent background (Katharopoulos and Fleuret, 2019). (2) Caries classification is a multiple instance learning (MIL) problem (Dietterich et al., 1997). That is, an image is considered positive if at least one carious lesion is present and negative if and only if no lesion is present. In this context, an image is described as a bag of image region features called instances; see Carbonneau et al. (2018) for an introduction. (3) BWRs contain multiple teeth, each of which may be affected by caries. However, classification outputs are restricted to a single probability score and thus lack interpretability (Zhang and Zhu, 2018). A supporting model should indicate where each lesion is located so that its correctness can be verified. (4) Optimal decision support is receptive to feedback (Holzinger, 2016). Beyond outputting information about the occurrence of caries (learned from weak labels), a model should allow a dentist or a teacher model (Hinton et al., 2015) to interact with it by providing strong labels (such as segmentation masks) to improve performance. We present Embedding Multiple Instance Learning (EMIL), an interpretable and interactive method that fulfills the above considerations. EMIL extracts 3D patches from the spatial embedding produced by any CNN.
Each patch may show caries and is classified individually; together, all predictions form a heatmap of local probabilities, notably without access to patch labels. An attention mechanism weighs the local predictions and aggregates them into a global image-level prediction. Besides standard classification, the method enables (but does not rely on) the inclusion of dense labels. Although EMIL adds important capabilities for the present use case, the classification of dental caries, it is a simple adaptation of common CNNs with low computational cost that translates to other diagnosis tasks. We evaluate performance and interpretability on a large clinical bitewing dataset for image- and tooth-level classification, and show the positive impact of including strong tooth and caries labels. Our code is available at: https://github.com/benbergner/emil. Recently, several caries prediction models have been published. Tripathi et al. (2019) used a genetic algorithm on 800 BWRs and reported an accuracy of 95.4%. Srivastava et al. (2017) trained a 100+ layer CNN on 2,500 BWRs and reported an F-score of 70%. Megalan Leo and Kalpalatha Reddy (2020) trained a CNN on 418 cropped teeth from 120 BWRs and achieved an accuracy of 87.6%. Kumar and Srivastava (2018) proposed an incremental learning approach and trained a U-Net on 6,000 BWRs, which yielded an F-score of 61%. Cantu et al. (2020) trained a U-Net on 3,686 BWRs and reported tooth-level accuracy and F-score of 80% and 73%, respectively. Bayraktar and Ayan (2021) trained YOLO on 800 bitewings and reported an AUC of 87%. MIL is commonly used for the classification of microscopy images, in which, e.g., a single cancer cell positively labels a bag (Kraus et al., 2016; Sudharshan et al., 2019). MIL has also recently been applied in radiology, e.g., Han et al. (2020) screened chest CTs for COVID-19 and Zhou et al. (2018) detected diabetic retinopathy in retinal images.
To the best of our knowledge, this is the first application of MIL to the field of dental radiology. Instance representations are commonly created from patches extracted from the input image (Xu et al., 2014), but can also be extracted from a CNN embedding (Pawlowski et al., 2020; Dosovitskiy et al., 2021). Furthermore, one can distinguish between approaches predicting at the instance level (Wu et al., 2015; Campanella et al., 2018) and at the bag level (Wang et al., 2018; Ilse et al., 2018). Our approach combines the extraction of instances as overlapping patches from a CNN embedding with the classification of individual instances, which are then aggregated with a constrained form of attention-based MIL pooling.

Below, we describe our proposed model, the creation of a local prediction heatmap, and a method to incorporate strong labels. A schematic of the architecture is shown in Figure 1. We consider an image $X \in \mathbb{R}^{H_X \times W_X \times C_X}$ as input, with $H_X$, $W_X$, $C_X$ being height, width and number of channels, and assign a binary label $y \in \{0, 1\}$. First, a(ny) convolutional backbone computes a feature map $U \in \mathbb{R}^{H_U \times W_U \times C_U}$:

$$U = \mathrm{CNN}(X). \tag{1}$$

Next, $K$ patches are extracted from $U$. For this purpose, a sliding window is used with kernel size $(H_P, W_P)$ and stride $(H_S, W_S)$. Each patch is spatially pooled, resulting in a feature matrix $P \in \mathbb{R}^{K \times C_U}$. We use average pooling, which is most prevalent in image classification (Lin et al., 2014):

$$p_k = \frac{1}{H_P W_P} \sum_{i=1}^{H_P} \sum_{j=1}^{W_P} U^{(k)}_{i,j,:}, \quad k = 1, \dots, K, \tag{2}$$

where $U^{(k)}$ denotes the $k$-th patch and $p_k$ the $k$-th row of $P$. Both patch extraction and pooling are implemented by an ordinary local pooling operation. Each patch may show a carious tooth region and is thus classified independently by a shared fully-connected layer parametrized by $o \in \mathbb{R}^{C_U}$. This is followed by a sigmoid operator, which outputs classifications $\tilde{y} \in \mathbb{R}^{K}$ holding class probabilities for each patch:

$$\tilde{y} = \sigma(P o). \tag{3}$$

We use a patch weight vector $w \in \mathbb{R}^{K \times 1}$ to focus on carious lesions while neglecting background and non-caries tooth regions.
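The patch extraction, average pooling, and per-patch sigmoid classification described above can be illustrated with a minimal pure-Python sketch. It assumes the backbone output U has already been computed (the CNN itself is not modeled), stores it as nested lists of shape H_U × W_U × C_U, and orders patches row-major. The function names and the optional `bias` term are hypothetical illustrations, not taken from the EMIL code base.

```python
import math
from typing import List, Tuple

# Sketch under assumptions: U is a precomputed backbone output; `bias` is a
# hypothetical addition (the text only names the weight vector o).
Vector = List[float]
FeatureMap = List[List[Vector]]  # indexed as U[h][w][c], shape H_U x W_U x C_U


def extract_and_pool(U: FeatureMap, kernel: Tuple[int, int],
                     stride: Tuple[int, int]) -> List[Vector]:
    """Slide a (H_P, W_P) window with stride (H_S, W_S) over U and average-pool
    each window spatially. Returns P as K rows of length C_U, ordered row-major."""
    H, W, C = len(U), len(U[0]), len(U[0][0])
    hp, wp = kernel
    hs, ws = stride
    P: List[Vector] = []
    for top in range(0, H - hp + 1, hs):
        for left in range(0, W - wp + 1, ws):
            acc = [0.0] * C
            for i in range(top, top + hp):
                for j in range(left, left + wp):
                    for c in range(C):
                        acc[c] += U[i][j][c]
            P.append([a / (hp * wp) for a in acc])
    return P


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def classify_patches(P: List[Vector], o: Vector, bias: float = 0.0) -> Vector:
    """Shared linear layer (weights o) followed by a sigmoid, one probability per patch."""
    return [sigmoid(sum(p_c * o_c for p_c, o_c in zip(p, o)) + bias) for p in P]


# Toy 4x4x1 feature map with values 0..15 and non-overlapping 2x2 patches -> K = 4.
U = [[[float(4 * h + w)] for w in range(4)] for h in range(4)]
P = extract_and_pool(U, kernel=(2, 2), stride=(2, 2))
print(P)        # [[2.5], [4.5], [10.5], [12.5]]
y_tilde = classify_patches(P, o=[0.0])
print(y_tilde)  # [0.5, 0.5, 0.5, 0.5]
```

With kernel size and stride 1, as used in the main experiments, every spatial position of U becomes its own patch, so K = H_U · W_U; larger strides trade heatmap resolution for fewer patches.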
The image-level prediction $\hat{y}$ is computed as:

$$\hat{y} = \frac{\sum_{k=1}^{K} w_k \tilde{y}_k}{\max\left(K_{\min}, \sum_{k=1}^{K} w_k\right)}. \tag{4}$$

The denominator ensures that at least $K_{\min}$ patches are attended to, and provides a way to include prior knowledge about the target's size. For caries classification, a single positive patch should lead to a positive prediction, so we set $K_{\min} = 1$ (see Appendix D for more details). The weight of each patch is determined by its own local representation. We use a variant of the gated attention mechanism (Ilse et al., 2018), which is a two-branch multilayer perceptron parametrized by $A \in \mathbb{R}^{C_U \times D}$, $B \in \mathbb{R}^{C_U \times D}$ and $c \in \mathbb{R}^{D \times 1}$, with $D$ hidden nodes:

$$w = \sigma\left(\left(\tanh(PA) \odot \sigma(PB)\right) c\right). \tag{5}$$

Compared to the original formulation using softmax, we employ the sigmoid as the outer function and normalize Eq. 4 accordingly. This makes the weights independent of each other and allows the model to ignore all patches, which is useful for classifying the negative class.

A heatmap $M_{\tilde{y}}$ is constructed with each element corresponding to a local prediction $\tilde{y}_k$. For visualization, the heatmap is interpolated and superimposed on the input image. Similarly, another heatmap $M_w$ is built from the patch weights $w$. While $M_{\tilde{y}}$ shows local predictions, $M_w$ indicates which areas are considered for global classification. Note that the values in $M_{\tilde{y}}$ and $M_w$ can be interpreted as probabilities, a property that attribution methods lack. The probability for any group of patches (e.g., a tooth) can be calculated with Eq. 4 by restricting both sums to the indices of those patches. Furthermore, EMIL is optimized for faithfulness (Alvarez Melis and Jaakkola, 2018). That is, since $\hat{y}$ is an explicit function of the patch weights and predictions, one can compute exactly by how much $\hat{y}$ changes when any patch $i$ is removed. For instance, if the total weight stays at or below $K_{\min}$ ($\sum_k w_k \le K_{\min}$), the change is $-w_i \tilde{y}_i / K_{\min}$; if all attended patches are predicted positive ($\tilde{y}_k = 1$ whenever $w_k > 0$) and the remaining weight $\sum_{k \ne i} w_k$ is still at least $K_{\min}$, $\hat{y}$ does not change. For example, if $K_{\min} = 1$ and 2 caries patches are present, $\hat{y}$ should not change when one caries patch is removed, which aligns with the standard MIL assumption (Foulds and Frank, 2010).

Optionally, learning can be guided by providing additional labels, such as segmentation masks.
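To make the aggregation and its interpretability properties concrete, the following pure-Python sketch implements the K_min-normalized weighted average, assuming the denominator has the form max(K_min, Σ_k w_k) (our reading of "ensures that at least K_min patches are attended to"), together with the sigmoid-gated attention and the exact patch-removal effect. Function names, the list-based matrix layout, and the toy values are illustrative assumptions.

```python
import math
from typing import List, Optional, Sequence

# Assumption: Eq. 4 denominator = max(K_min, sum of weights). Names are hypothetical.
Vector = List[float]
Matrix = List[List[float]]  # row-major, C_U rows x D columns


def sigmoid(z: float) -> float:
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def gated_sigmoid_attention(P: List[Vector], A: Matrix, B: Matrix, c: Vector) -> Vector:
    """w_k = sigmoid((tanh(p_k A) * sigmoid(p_k B)) . c), independently per patch."""
    D = len(c)
    w = []
    for p in P:
        a = [sum(p[i] * A[i][d] for i in range(len(p))) for d in range(D)]
        b = [sum(p[i] * B[i][d] for i in range(len(p))) for d in range(D)]
        gated = [math.tanh(a_d) * sigmoid(b_d) for a_d, b_d in zip(a, b)]
        w.append(sigmoid(sum(g * c_d for g, c_d in zip(gated, c))))
    return w


def aggregate(y_tilde: Vector, w: Vector, k_min: float = 1.0,
              subset: Optional[Sequence[int]] = None) -> float:
    """y_hat = sum_k w_k*y_k / max(K_min, sum_k w_k); `subset` restricts both
    sums to a group of patches (e.g., one tooth)."""
    idx = range(len(w)) if subset is None else subset
    num = sum(w[k] * y_tilde[k] for k in idx)
    den = max(k_min, sum(w[k] for k in idx))
    return num / den if den > 0 else 0.0  # den == 0 only if k_min == 0 and no weight


def removal_effect(y_tilde: Vector, w: Vector, i: int, k_min: float = 1.0) -> float:
    """Exact change of y_hat when patch i is dropped from both sums."""
    rest = [k for k in range(len(w)) if k != i]
    return aggregate(y_tilde, w, k_min, rest) - aggregate(y_tilde, w, k_min)


y_tilde = [1.0, 1.0, 0.1, 0.0]  # two caries patches, two background patches
w = [1.0, 1.0, 0.0, 0.0]        # attention only on the caries patches
print(aggregate(y_tilde, w))            # 1.0
print(removal_effect(y_tilde, w, 0))    # 0.0: one remaining lesion suffices
print(aggregate(y_tilde, [0.0] * 4))    # 0.0: all patches ignored -> negative
print(removal_effect(y_tilde, [0.5, 0.25, 0.0, 0.0], 0))  # -0.5 = -w_0*y_0/K_min
print(aggregate(y_tilde, [0.01, 0.01, 0.0, 0.0], k_min=0.0))  # 1.0: scale of w irrelevant
print(aggregate(y_tilde, [0.01, 0.01, 0.0, 0.0]))             # 0.02 with K_min = 1
```

Because each w_k passes through its own sigmoid, the weights do not compete as in a softmax; driving all of them toward zero yields ŷ ≈ 0 whenever K_min > 0. The last two calls show why K_min = 0 is degenerate: tiny weights then still produce a confident prediction.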
Such guidance is useful, for example, when a dentist wants to interactively correct errors or biases, or when a data scientist wants to incorporate dense (and expensive) labels for a subset of the data. To create patch-wise labels $\bar{y} \in \{0, 1\}^{K}$, a downscaled binary annotation mask is max-pooled with the kernel size and stride from Sect. 3.1 and vectorized. Then, to compute the compound loss $L$ for an image, the image- and patch-level cross-entropy losses $\ell = [L_{\text{image}}, L_{\text{patch}}]$ are weighted and added:

$$L = \sum_{j} \alpha_j \ell_j, \qquad \alpha_j = \frac{\max(\ell)}{\ell_j}.$$

Due to class imbalance in caries masks, the network easily fits the background class, and we observe that $L_{\text{patch}} \ll L_{\text{image}}$. Thus, $L_{\text{image}}$ dominates the compound loss and diminishes the benefit of strong labels. To mitigate this problem, the coefficient $\alpha$ dynamically scales each partial loss to the magnitude of the largest one. Note that $\alpha$ is treated as a constant: it is computed from detached copies of the partial losses, so no gradients flow through it.

Below, we describe the experiments and answer the following research questions: (1) How well can EMIL predict caries in BWRs and cropped tooth images? (2) Can it highlight caries and provide clinical insight? (3) To what extent do strong labels improve performance?

The dataset stems from three dental clinics in Brazil specializing in radiographic and tomographic examinations. It consists of 38,174 BWRs (corresponding to 316,388 cropped tooth images) taken between 2018 and 2021 from 9,780 patients with a mean (sd) age of 34 (14) years. Tooth-level caries labels were extracted from electronic health records (EHRs) that summarize a patient's dental status. In addition to these EHR-based ground-truth labels, which are associated with uncertainties and biases (Gianfrancesco et al., 2018), a random sample of 355 BWRs was drawn and annotated with caries masks by 4 experienced dentists, yielding 254 positive and 101 negative cases. These annotations were reviewed by a senior radiologist (13+ years of experience) to resolve conflicts and establish a test set.
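The construction of patch-wise labels and the α-scaled compound loss from the guided-learning paragraph above can be sketched in pure Python. This is a sketch under stated assumptions: α_j = max(ℓ)/ℓ_j (our reading of "scale each partial loss to the magnitude of the largest one"), a patch loss averaged over the K patches, and plain binary cross-entropy without the inverse-frequency class weighting used in the tooth-level task. All function names are hypothetical; since plain floats carry no gradients, the detaching of α is only described in a comment.

```python
import math
from typing import List, Sequence, Tuple

# Assumptions: alpha_j = max(l) / l_j; patch loss averaged over K; no class weighting.


def patch_labels(mask: List[List[int]], kernel: Tuple[int, int],
                 stride: Tuple[int, int]) -> List[int]:
    """Max-pool a downscaled binary mask with the patch kernel/stride and
    vectorize row-major: a patch is positive if any of its pixels is."""
    H, W = len(mask), len(mask[0])
    hp, wp = kernel
    hs, ws = stride
    return [max(mask[i][j] for i in range(top, top + hp) for j in range(left, left + wp))
            for top in range(0, H - hp + 1, hs)
            for left in range(0, W - wp + 1, ws)]


def bce(p: float, t: float, eps: float = 1e-7) -> float:
    """Binary cross-entropy with clipping to avoid log(0)."""
    p = min(max(p, eps), 1.0 - eps)
    return -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))


def compound_loss(y_hat: float, y: int, y_tilde: Sequence[float],
                  y_patch: Sequence[int]) -> Tuple[float, List[float]]:
    """L = sum_j alpha_j * l_j with l = [L_image, L_patch], alpha_j = max(l) / l_j.
    In an autodiff framework, alpha would be computed from detached losses so that
    it rescales the gradients of each term without receiving gradients itself."""
    l_image = bce(y_hat, y)
    l_patch = sum(bce(p, t) for p, t in zip(y_tilde, y_patch)) / len(y_patch)
    losses = [l_image, l_patch]
    largest = max(losses)
    alphas = [largest / l for l in losses]  # alpha_j * l_j == max(l) up to rounding
    return sum(a * l for a, l in zip(alphas, losses)), alphas


mask = [[0, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]
y_patch = patch_labels(mask, kernel=(2, 2), stride=(2, 2))
print(y_patch)  # [1, 0, 0, 0]
total, alphas = compound_loss(y_hat=0.6, y=1, y_tilde=[0.9, 0.1, 0.1, 0.1], y_patch=y_patch)
print(alphas)   # alpha_image is 1.0; alpha_patch > 1 scales the small patch loss up
```

In this toy case the well-fitted patch loss is much smaller than the image loss, mirroring the L_patch ≪ L_image situation described above; the scaling gives both terms equal magnitude in the sum.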
We consider caries classification at the bitewing and tooth level and use stratified 5-fold cross-validation with non-overlapping patients for training and hyperparameter tuning. Due to class imbalance, the balanced accuracy is used as the stopping criterion. In the tooth-level task, both class terms in $L_{\text{image}}$ are weighted by the inverse class frequency to account for class imbalance. Results are reported on the hold-out test set as the average of the 5 resulting models with 95% CIs. Binary masks from two teacher models are used to simulate interactivity: (1) a tooth instance-segmentation model (unpublished) that points at affected teeth and (2) a caries segmentation model (Cantu et al., 2020). Note that these models are subject to errors and do not replace class labels, but only guide training; if a segmentation contradicts the classification label, it is discarded. More training details are described in Appendix C.

Several baselines are used to show competitive performance. ResNet-18 (He et al., 2016) serves as the backbone for all methods. Since EMIL makes only a few changes to the default CNN, we also employ ResNet-18 itself as a baseline. To study the effect of embedding-based patch extraction, we compare to DeepMIL (Ilse et al., 2018), which operates on patches cropped from the input image; patch sizes of 32 and 128 px with 50% overlap are used. To show the effect of our patch weighting and aggregation approach, we replace the attention mechanism with the max operator, which is common in instance-based MIL (Amores, 2013). However, this variant did not fit the training data. A more powerful baseline is the hybrid version of the Vision Transformer (ViT) (Dosovitskiy et al., 2021) with a single encoder block. As in EMIL, patches are extracted from the output of the last convolutional layer, and we found large overlapping patches to be beneficial. The hybrid base and pure attention versions were not included because they performed worse or did not fit the training data.
For the evaluation of the interactive settings, a simple baseline consists of attaching a 1x1 convolutional layer (Lin et al., 2014) with a single output channel, stride 1 and no padding to the last ResNet-18 encoder block to output a segmentation map. As a stronger baseline, we adapt Y-Net (Mehta et al., 2018), which consists of a U-Net (with a ResNet-18 encoder) and a standard output layer attached to the last encoder block for classification. In all interactive settings, the same loss functions are used (Eq. 6-8). See also Appendix B.

Table 1 and Fig. 2 summarize the results. In bitewings, EMIL shows the highest mean accuracy, F-score and sensitivity across settings (in bold). In contrast, increasing capacity (ResNet-50) or using more complex self-attention-based aggregation (ViT) does not improve performance. Furthermore, DeepMIL-32/128 exhibit lower accuracy, indicating that context beyond patch borders is crucial for classifying caries. The use of tooth/caries masks increases mean performance. In particular, mask-guided EMIL exhibits higher scores across metrics than the non-guided model and has a significantly higher F-score and AUROC than ResNet-18. These trends continue in the tooth-level data, although the improvements of guided models are smaller. This is expected because the signal-to-noise ratio is higher in cropped tooth images than in bitewings. We also report results for the clinical labels (EHR GT) on which all models are based. The summary metrics (Bal. Acc., F-score) are higher for bitewings, possibly because errors at the tooth level may still lead to true positives at the image level, where a single detected lesion suffices. Intriguingly, all tooth-level models show a higher accuracy/F-score than the EHR GT. This suggests that the models learn salient patterns that clinicians missed (or did not report), which may result from the fact that mislabeled false positives have less weight in the loss function. Table 1 also reports average runtimes per iteration and peak memory usage for a realistic batch size of 16.
EMIL is nearly as efficient as its underlying backbone and up to 6.6× faster than Y-Net, while consuming up to 9.5× less memory at similar or better mean performance. Post-hoc attribution maps, as well as DeepMIL, are sensitive to positive cases (Fig. 4, rows 1-3) but not precise. Moreover, these methods do not ignore the negative class (row 4), and false negatives are accompanied by false positive visualizations (row 5). This is resolved in Y-Net and EMIL ($M_{\tilde{y}}$), where caries may be highlighted even though the activation is too low to cross the classification threshold (e.g., row 5, column 8). Table 2 reports localization performance; to binarize the maps, we set the highest-valued pixels of each map to 1, so that the total area equals the respective ground truth. When considering all confidences (IoU@0), Saliency and EMIL localize best. In the interactive settings, scores improve significantly for Y-Net and EMIL. We also conduct an experiment in which only confident predictions ($\hat{y} \ge 0.95$) are retained. Average localization performance increases most for EMIL models, by 12.4 and 30.9 percentage points, respectively.

We presented two caries classifiers for bitewing and tooth images. The former indicates the general presence of caries in the dentition (with high PR AUC), while the latter may support diagnosis. One limitation is that training is performed with labels extracted from EHRs, which are error-prone. However, our dataset is much larger than those in related work, and relabeling at scale is impractical. Nevertheless, the tooth-level model shows higher sensitivity than clinicians (62.68 ± 4.37% vs. 44.27%), suggesting that more lesions can be found and treated in practice. The heatmaps are a useful tool to see on what grounds a prediction is made and to estimate caries severity. Furthermore, we showed that strong labels pointing at relevant regions improve classification and localization, opening up ways to integrate the user into the training process.
Technically, our approach may serve further computer-aided diagnosis applications in radiology, where trust and the ability to integrate human knowledge are critical.

Table 3 shows classification results and prevalence for different tooth and caries types for the guided EMIL model. In terms of tooth types, the model performs significantly worse for canines, which may be explained by their low prevalence in the data. A higher average F-score is observed for premolars than for molars. Premolars may be easier to detect because they appear centrally in the bitewing and are unlikely to be partially cut out of the image. For secondary/recurrent caries, the average sensitivity is higher than for primary/initial caries. One possible reason is that such lesions are adjacent to restorations, which are radiopaque, sharply demarcated and easy to detect. In addition, secondary caries can spread more quickly because it no longer has to penetrate the hard enamel and can directly reach the softer interior of the tooth.

In Table 4, we revisit the baselines used in the main paper and make a conceptual comparison. EMIL and ViT (hybrid) extract patches from a CNN embedding, while DeepMIL extracts patches from the input image. The embedding approach has the advantage of a large context, owing to the receptive field that grows with the sequence of convolutional layers. The MIL literature (Carbonneau et al., 2018; Amores, 2013) distinguishes between two types of inputs to the classification function. The first type is the bag representation, which is calculated by aggregating instances with a MIL pooling operation (such as mean or attention). The second type, used by EMIL, consists of individual instances that are classified before aggregation. Standard CNNs instead use a global embedding, without distinguishing between instances and bags. There are also different assumptions about when a bag is considered positive (Foulds and Frank, 2010).
The most common is the standard assumption (any positive instance → bag positive; no positive instance → bag negative). In the weighted collective assumption (used by DeepMIL), all instances are considered in a weighted manner to infer the class of the bag. EMIL uses both the weighted collective and the threshold-based assumption, where a minimum number of patches ($K_{\min}$) must be positive for the bag to be classified as positive. Standard CNNs are black boxes that require post-hoc attribution methods to give insights into their predictions (in ViT, attention maps can be visualized as well). The drawback of such methods is that they are not optimized for faithfulness (Adebayo et al., 2018; Rudin, 2019) and cannot explain the negative class (see Fig. 4, row 4). Y-Net is interpretable through its decoder but requires segmentation labels and does not use the decoder output for classification. DeepMIL uses attention, but its weights must sum to 1, which is unintuitive for the negative class. EMIL uses both attention weights and patch probabilities to create faithful explanations for a prediction. Regarding interactive learning, standard CNNs are trained with classification labels but cannot learn from dense labels. Y-Net and EMIL are both able to learn from dense labels, but EMIL does so more efficiently.

For DeepMIL, the implementation of Ilse et al. (2018) is used. For ViT, we adapt the vit-pytorch repository and found that a minimal hybrid version using a single transformer encoder block with 8 heads (each 64-dimensional) works best. Both patch representations and inner MLP layers are 128-dimensional. No dropout is used, and all patch representations are averaged before the MLP head. We employ the same approach as in EMIL and create overlapping patches with a large kernel size of 5 and a stride of 1. For Y-Net, we adapt the residual U-Net implementation of the ResUnet repository, where we add a fully-connected output layer to the bottleneck and learn the upsampling in the decoder.
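Returning to the bag-labeling assumptions in the conceptual comparison above, the three can be contrasted on a toy bag with a single positive instance among ten. This is a simplified illustration: it applies the weighted collective assumption to instance probabilities (a common formulation), whereas DeepMIL applies attention weights to instance embeddings. Function names and the 0.5 decision threshold are assumptions.

```python
from typing import Sequence

# Simplified illustration; DeepMIL weights embeddings, not probabilities.


def standard(instance_labels: Sequence[int]) -> bool:
    """Standard assumption: positive iff at least one instance is positive."""
    return any(instance_labels)


def weighted_collective(instance_probs: Sequence[float], weights: Sequence[float],
                        threshold: float = 0.5) -> bool:
    """Bag probability = weighted mean of instance probabilities (weights normalized)."""
    total = sum(weights)
    bag_prob = sum(w * p for w, p in zip(weights, instance_probs)) / total
    return bag_prob >= threshold


def threshold_based(instance_labels: Sequence[int], k_min: int = 1) -> bool:
    """Threshold-based assumption: positive iff at least k_min instances are positive."""
    return sum(instance_labels) >= k_min


labels = [1] + [0] * 9  # one carious patch among ten
probs = [float(v) for v in labels]

print(standard(labels))                                 # True
print(weighted_collective(probs, [1.0] * 10))           # False: uniform weights dilute it (0.1)
print(weighted_collective(probs, [19.0] + [1.0] * 9))   # True: attention on the lesion (~0.68)
print(threshold_based(labels, k_min=1))                 # True
print(threshold_based(labels, k_min=2))                 # False
```

The uniform-weight case shows why a small lesion can be drowned out by background patches unless the weighting concentrates on it; in EMIL, the attention weights provide this concentration and the K_min term in the denominator of Eq. 4 plays the role of the threshold.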
For saliency maps, occlusion sensitivity and Grad-CAM, we use the Captum library (Kokhlikyan et al., 2020).

The dataset consists of 38,174 bitewings, which corresponds to 316,388 teeth. To prepare the data, we make use of the following exclusion criteria. We use the Adam optimizer (Kingma and Ba, 2015) with a learning rate of 0.001, $\beta_1 = 0.9$, $\beta_2 = 0.999$ and no weight decay. In each fold, we train for 20-30 epochs depending on the training progress of the respective method. We observed that training in the interactive settings converges faster, which is expected as the segmentation masks guide the model to the salient patterns. A batch size of 32 is used for bitewings and 128 for the tooth data. For Y-Net, we had to reduce the batch size to 16 and 32, respectively, due to memory constraints.

Segmentation masks originally have the same dimensions as the input image, whereas EMIL uses downscaled masks. To prevent small carious lesions from disappearing due to downscaling, EMIL performs bilinear upsampling of the encoder output by a factor of 4 before extracting patches, resulting in spatial feature map resolutions of 64x84 for bitewing images and 48x48 for tooth images. Note that the primary task is classification, i.e., segmentation masks are used to guide training but do not replace the classification label. A positive mask is only used if it corresponds to the class label; if the class label is negative, all elements of the mask are set to 0. Due to label noise, we do not use negative masks for the tooth-level task. After applying these filters, 35,683 masks (∼97%) remain for the bitewing data, and 43,178 masks (∼72% of all positive instances) remain for the tooth-level data. For more details on the performance of the caries segmentation model, see Cantu et al. (2020).

EMIL has two hyperparameters that we explain in more detail: $K_{\min}$ and the patch size.
The hyperparameter $K_{\min}$ represents the minimum collective weight that must be assigned to the set of patches to obtain a confident positive classification (i.e., $\hat{y} = 1$). For simplicity, consider the case where attention weights can only take values in $\{0, 1\}$. Then $K_{\min}$ can be thought of as the minimum number of patches that must be attended to. If this constraint is violated, the denominator of Eq. 4 turns into a constant, and the network is incentivized (for the positive class) to attend to more patches by increasing $w$ through the numerator. Note that the value of $\hat{y}$ also depends on $\tilde{y}$, i.e., attended patches must be classified as positive to obtain a high positive class score. Fig. 5 shows the effect of $K_{\min}$ on the patch weight map. For increasing values of $K_{\min}$, sensitivity increases but precision decreases. If the value is too high, performance decreases because healthy tooth regions will be attended to, which erroneously reduces the disease probability (see, e.g., the first row of Fig. 5). When $K_{\min} = 0$, little attention is assigned to any patch because all possible class scores can be obtained independently of the magnitude of $w$. According to the standard MIL assumption (Dietterich et al., 1997; Foulds and Frank, 2010), a single positive instance is sufficient to positively label a bag; therefore, $K_{\min}$ is set to 1 and does not need to be tuned.

Figure 5: Weight maps for increasing $K_{\min}$ (panel titles include GT and Tooth). Sensitivity increases, precision decreases.
Figure 6: Unnormalized weight maps for increasing patch sizes, with $H_P = W_P$. Sensitivity increases, precision decreases.

The second hyperparameter is the patch size, which we set equal for both dimensions, $H_P = W_P$; in the following, we use $H_P$ to denote both height and width. The patch size controls the individual regions that are classified. If $H_P = H_U$ (and $W_P = W_U$), a single global patch is considered and the training behavior is similar to that of a standard CNN. Fig. 6 shows the effect of $H_P$ on the heatmap for $H_S = 1$.
Attention weights of overlapping patches are summed and clipped at 1 to improve visualizations. One can observe that sensitivity increases while precision decreases. In our experiments, the patch size had little impact on classification performance, but we prefer a small value for precise localization. Note that a small patch in the embedding space has a large receptive field and thus sufficient context to detect both small and larger lesions (Luo et al., 2016). For the main experiments, we set $K_{\min} = 1$, $H_P = W_P = 1$ and $H_S = W_S = 1$.

Sanity checks for saliency maps
Towards robust interpretability with self-explaining neural networks
Multiple instance classification: Review, taxonomy and comparative study
Diagnosis of interproximal caries lesions with deep convolutional neural network in digital bitewing radiographs
Terabyte-scale deep multiple instance learning for classification and localization in pathology
Detecting caries lesions of different radiographic extension on bitewings using deep learning
Multiple instance learning: A survey of problem characteristics and applications
Solving the multiple instance problem with axis-parallel rectangles
An image is worth 16x16 words: Transformers for image recognition at scale
A review of multi-instance learning assumptions. The Knowledge Engineering Review
Potential biases in machine learning algorithms using electronic health record data
Accurate screening of COVID-19 using attention-based deep 3D multiple instance learning
Deep residual learning for image recognition
Distilling the knowledge in a neural network
Interactive machine learning for health informatics: when do we need the human-in-the-loop?
Attention-based deep multiple instance learning
Global, regional, and national prevalence, incidence, and disability-adjusted life years for oral conditions for 195 countries, 1990-2015: a systematic analysis for the global burden of diseases, injuries, and risk factors
Processing megapixel images with deep attention-sampling models
Adam: A method for stochastic optimization
Captum: A unified and generic model interpretability library for PyTorch
Classifying and segmenting microscopy images with deep multiple instance learning
Example mining for incremental learning in medical imaging
Understanding the effective receptive field in deep convolutional neural networks
Dental caries classification system using deep learning based convolutional neural network
Y-Net: joint segmentation and classification for diagnosis of breast biopsy images
Needles in haystacks: On classifying tiny objects in large images
Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead
Radiographic caries detection: a systematic review and meta-analysis
Grad-CAM: Visual explanations from deep networks via gradient-based localization
Deep inside convolutional networks: Visualising image classification models and saliency maps
Detection of tooth caries in bitewing radiographs using deep learning
Multiple instance learning for histopathological breast cancer image classification
Genetic algorithms-based approach for dental caries detection using back propagation neural network
Saliency is a possible red herring when diagnosing poor generalization
Revisiting multiple instance neural networks
Deep multiple instance learning for image classification and auto-annotation
Deep learning of feature representation with multiple instance learning for medical image analysis
Visualizing and understanding convolutional networks
Visual interpretability for deep learning: a survey
Deep multiple instance learning for automatic detection of diabetic retinopathy in retinal images

This project has received funding from the German Federal Ministry of Education and Research (BMBF) in the projects SyReal (project number 01IS21069A) and KI-LAB-ITSE (project number 01IS19066).

When correct, both models detect lesions with few false positive visualizations. One reason for misclassification is low attention weights. For example, consider the first row of Fig. 8, where the patch prediction heatmap weakly highlights both lesions, but little attention is assigned to them. Nevertheless, a dentist may use these maps to detect caries and mark lesions so that the network can learn to locate them explicitly.