Blindness detection (Diabetic retinopathy) using Deep learning on Eye retina images
Automating the process with convolutional neural networks (in Python) to speed up blindness detection in patients before it's too late
Table of contents
- Use of Deep learning to detect Blindness
- Evaluation metric (Quadratic weighted kappa)
- Image processing and analysis
- Implementation of an arXiv.org research Paper (Top 1% solution) using Multi Task Learning
- Other Transfer Learning Models
- Future work
- Link to GitHub code and LinkedIn profile
- References used
1. Use of Deep learning to detect Blindness
This case study is based on the APTOS 2019 Blindness Detection challenge on Kaggle — https://www.kaggle.com/c/aptos2019-blindness-detection/overview .
Millions of people suffer from diabetic retinopathy, the leading cause of blindness among working-aged adults. Aravind Eye Hospital in India hopes to detect and prevent this disease among people living in rural areas where medical screening is difficult to conduct. Currently, technicians travel to these rural areas to capture images and then rely on highly trained doctors to review the images and provide a diagnosis.
The goal here is to scale their efforts through technology: to gain the ability to automatically screen images for disease and provide information on how severe the condition may be. We shall achieve this by building a convolutional neural network model that can automatically look at a patient's eye image and estimate the severity of blindness. Automating this process can save a great deal of time and make screening for diabetic retinopathy feasible at large scale.
We are given 3,662 eye images and their corresponding severity grades, each one of [0, 1, 2, 3, 4]. This data is used for training the model, and prediction is done on the test data.
Sample eye images in the dataset are below —
2. Evaluation metric (Quadratic weighted kappa)
Quadratic weighted kappa measures the agreement between two ratings. This metric typically varies from 0 (random agreement between raters) to 1 (complete agreement between raters). In the event that there is less agreement between the raters than expected by chance, the metric may go below 0. The quadratic weighted kappa is calculated between the scores assigned by the human rater and the predicted scores.
Intuition of Cohen’s kappa
To understand this metric, we need to understand the concept of Cohen's kappa (wikipedia — link). This metric accounts for the agreement that occurs by chance, in addition to the agreement we observe from the confusion matrix. We can understand this using a simple example —
Suppose we want to calculate Cohen's kappa from the agreement table below, which is basically a confusion matrix. As we can see from the table, out of 50 total observations, raters 'A' and 'B' agree on (20 yes + 15 no) = 35 observations. So the observed agreement is P(o) = 35/50 = 0.7.
We also need to find what proportion of the ratings agree due to random chance rather than actual agreement between 'A' and 'B'. 'A' said yes 25/50 = 0.5 of the time, and 'B' said yes 30/50 = 0.6 of the time, so the probability that both say 'yes' simultaneously at random is 0.5 × 0.6 = 0.3. Similarly, the probability that both say 'no' simultaneously at random is 0.5 × 0.4 = 0.2. Thus, the probability of random agreement is 0.3 + 0.2 = 0.5. Let us call this P(e).
So, Cohen's kappa is κ = (P(o) − P(e)) / (1 − P(e)) = (0.7 − 0.5) / (1 − 0.5) = 0.4.
The value we got is 0.4, which we can interpret using the table below. You can read more on interpreting this metric in this blog — https://towardsdatascience.com/inter-rater-agreement-kappas-69cd8b91ff75 .
A value of 0.4 would mean we have a fair/moderate agreement.
Intuition of Quadratic weights in ordinal classes — Quadratic weighted kappa
We can extend the same concept to multiple classes and introduce the concept of weights. The intuition behind the weights is that the classes are ordinal variables: an agreement between classes '1' and '2' is better than one between classes '1' and '3', because they are nearer on the ordinal scale. In our case, the output is ordinal blindness severity (0 representing no blindness, 4 the highest severity). To account for this, weights on a quadratic scale are introduced: the closer the ordinal classes, the higher their weight. Here is a blog post that explains this concept very well — https://medium.com/x8-the-ai-community/kappa-coefficient-for-dummies-84d98b6f13ee .
As we can see, R1 and R4 have a weight of 0.44 (because they are 3 ordinal classes apart), compared to R1 and R2, which have a weight of 0.94 (because they are closer). These weights are multiplied by the corresponding probabilities when calculating the observed and chance agreement.
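For a quick sanity check, scikit-learn computes this metric directly via cohen_kappa_score with weights='quadratic'. A minimal sketch (the rater scores below are made-up values for illustration):

# Quadratic weighted kappa with scikit-learn
from sklearn.metrics import cohen_kappa_score

# Hypothetical scores on the 0-4 severity scale
human_rater = [0, 2, 4, 1, 3, 0, 2]
model_preds = [0, 2, 3, 1, 4, 0, 1]

# weights='quadratic' penalizes disagreements by the squared
# distance between the ordinal classes
qwk = cohen_kappa_score(human_rater, model_preds, weights='quadratic')
print('Quadratic weighted kappa: %.3f' % qwk)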
3. Image processing and analysis
Class distribution of output variable
As we can see, there is class imbalance in the training dataset, with most cases having the value '0' and the fewest falling in classes '1' and '3'.
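A minimal sketch of how this distribution can be plotted, assuming the labels live in the competition's train.csv with a 'diagnosis' column:

# Plot the class distribution of the severity labels
import pandas as pd
import matplotlib.pyplot as plt

train_df = pd.read_csv('train.csv')                 # assumed file layout
train_df['diagnosis'].value_counts().sort_index().plot(kind='bar')
plt.xlabel('Blindness severity class')
plt.ylabel('Number of images')
plt.show()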
Visualize Eye images
These are eye retina images taken using fundus photography. As we can see, the images contain artifacts: some are out of focus, underexposed, or overexposed. Some images also have low brightness and poor lighting conditions, making it difficult to assess the differences between them.
Below is a sample retina image of an eye that contains diabetic retinopathy. Source: kaggle kernel link. Original source — https://www.eyeops.com/
As we can see, we are trying to spot features such as hemorrhages and exudates in the higher classes, which have DR.
Image processing
To make the images clearer and enable the model to learn features more effectively, we will carry out some image-processing techniques using the OpenCV library in Python (cv2).
We can apply Gaussian blur to bring out distinctive features in the images. In a Gaussian blur operation, the image is convolved with a Gaussian filter, a low-pass filter that removes high-frequency components.
This brilliantly written kernel (link) introduces the idea of circular cropping from gray-scale images. The code section below implements the same idea.
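A sketch of that idea, adapted from the referenced kernel (the tolerance and sigmaX values are assumptions):

import cv2
import numpy as np

def crop_image_from_gray(img, tol=7):
    # Build a mask of "bright enough" pixels from the gray-scale image
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    mask = gray > tol
    if not mask.any():          # completely dark image; leave unchanged
        return img
    # Keep only the rows/columns that contain part of the retina circle
    return img[np.ix_(mask.any(1), mask.any(0))]

def load_ben_color(path, img_size=256, sigmaX=10):
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = crop_image_from_gray(img)
    img = cv2.resize(img, (img_size, img_size))
    # Blend with a heavily blurred copy: subtracting low frequencies
    # highlights local structure (vessels, hemorrhages, exudates)
    return cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), sigmaX), -4, 128)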
As we can see, the distinctive patterns in the images are now much clearer. Here is the image processing applied to 15 image samples.
Multiprocessing and resizing images to save to a directory
After applying the above operations, we need to save the new images to a folder for later use. The original images are around 3 MB each, and the entire image folder occupies 20 GB. We can reduce this by resizing the images. Using multi-threading, we can finish this task in under a minute: a ThreadPool with 6 workers (since I have an 8-core CPU) and an image size of 512x512 (IMG_SIZE).
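A sketch of the parallel resize, assuming the raw images sit in a train_images folder:

import os
import cv2
from multiprocessing.pool import ThreadPool

IMG_SIZE = 512
SRC_DIR, DST_DIR = 'train_images', 'train_images_resized'   # assumed paths

def resize_and_save(fname):
    img = cv2.imread(os.path.join(SRC_DIR, fname))
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    cv2.imwrite(os.path.join(DST_DIR, fname), img)

os.makedirs(DST_DIR, exist_ok=True)
with ThreadPool(6) as pool:      # 6 worker threads on an 8-core CPU
    pool.map(resize_and_save, os.listdir(SRC_DIR))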
TSNE visualization
In order to understand whether the images are separable into their respective classes (blindness severity), we can first use TSNE to visualize them in 2 dimensions. We first convert the RGB images to gray scale and then flatten them to generate a vector representation that serves as a feature representation of each image.
RGB (256x256x3) -> GRAY SCALE (256x256) -> FLATTEN (65536)
Perplexity is the hyperparameter that needs to be tuned to get good results. After several iterations, we plot the TSNE output for perplexity = 40.
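A sketch of the TSNE step, assuming the flattened gray-scale features and labels have been precomputed and saved:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

X_gray = np.load('gray_flattened.npy')   # (n_images, 65536), assumed
y = np.load('labels.npy')                # (n_images,), assumed

tsne = TSNE(n_components=2, perplexity=40, random_state=42)
X_2d = tsne.fit_transform(X_gray)

plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y, cmap='tab10', s=5)
plt.colorbar(label='severity class')
plt.show()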
As we can see from the above plot, class '0' is somewhat separable from the other classes, whereas the distinction between the remaining classes is still quite vague.
Image augmentations
This is one of the most widely used procedures for adding robustness to the data: it creates additional images from the dataset (rotations, flips, cropping, padding, etc.) so the model generalizes better to new data. We use the keras ImageDataGenerator class, as in the sketch below.
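A minimal sketch of such a generator (the specific parameter values are assumptions; retina images have no canonical orientation, so rotations and flips are safe):

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=360,     # any rotation of a retina is still a retina
    horizontal_flip=True,
    vertical_flip=True,
    zoom_range=0.2,
    fill_mode='constant',
    cval=0.0,
)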
The above code generates sample images after applying the augmentations —
4. Implementation of an arXiv.org research Paper (Top 1% solution) using Multi Task Learning
One of the research papers found on arXiv.org (research paper link), whose solution ranked 54th out of 2,943 (top 1%), takes a detailed Multi-Task Learning approach to this problem statement. Below is an attempt to replicate the research paper's approach in Python; no code reference for this paper was found on GitHub. I have summarized the paper in a Research Paper (summary doc).pdf file — github link. Summarizing pointers and the code implementation from the paper below —
Datasets Used
The actual research paper uses multiple datasets from other sources like —
- 35,216 images from Diabetic Retinopathy,2015 challenge — https://www.kaggle.com/c/diabetic-retinopathy-detection/overview/timeline
- Indian Diabetic Retinopathy Image Dataset (IDRiD) (Sahasrabuddhe and Meriaudeau, 2018) — 413 images used
- MESSIDOR dataset (Google Brain, 2018)
- The full dataset consists of 18,590 fundus photographs, divided into 3,662 training, 1,928 validation, and 13,000 testing images by the organizers of the Kaggle competition
However, since not all of these datasets are easily available, we use only the existing APTOS 2019 dataset for this task.
Image Preprocessing and Augmentations
Multiple image preprocessing techniques, such as resizing and cropping, were used to bring out distinctive features from the eye images (as discussed in the sections above).
The image augmentations used were optical distortion, grid distortion, piecewise affine transform, horizontal flip, vertical flip, random rotation, random shift, random scale, a shift of RGB values, random brightness and contrast, additive Gaussian noise, blur, sharpening, embossing, random gamma, and cutout.
However, we could use alternative image augmentations similar to the above (reference used — link).
Model Architecture Used
As we can see below, the research paper uses a Multi-Task Learning model: it trains for regression, classification, and ordinal regression in parallel. This way it uses a single model, and since the first layers would learn similar features anyway, this architecture reduces training time (instead of training 3 separate models). For the encoder, we could use any existing CNN architecture — ResNet50, EfficientNetB4, EfficientNetB5 (and ensemble these).
The code implementation for the above architecture is below (in keras) —
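A sketch of this implementation (the head sizes and dropout rate are assumptions; the output-layer names match the compile calls used later):

from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, Dropout, GlobalAveragePooling2D
from keras.models import Model

encoder = ResNet50(include_top=False, weights='imagenet',
                   input_shape=(256, 256, 3))

# Shared trunk on top of the encoder
x = GlobalAveragePooling2D()(encoder.output)
x = Dropout(0.3)(x)
x = Dense(512, activation='relu')(x)

# Three heads trained in parallel on the same shared features
regression_head = Dense(1, activation='linear',
                        name='regression_output')(x)
classification_head = Dense(5, activation='softmax',
                            name='classification_output')(x)
ordinal_head = Dense(5, activation='sigmoid',
                     name='ordinal_regression_output')(x)

model = Model(inputs=encoder.input,
              outputs=[regression_head, classification_head, ordinal_head])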
Creating Multi Output — Custom Image Data Generator
We need to create a custom ImageDataGenerator function since we have 3 outputs (reference — link). Note that this function yields 3 outputs (regression, classification, ordinal regression). Code sketch below —
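A minimal sketch of such a wrapper (it assumes an underlying generator that yields image batches with sparse integer class labels):

import numpy as np
from keras.utils import to_categorical

def multi_output_generator(flow):
    # Re-emit each (images, labels) batch with the three heads' targets
    for X, y in flow:
        y = y.astype(int)
        yield X, {
            'regression_output': y.astype('float32'),
            'classification_output': to_categorical(y, num_classes=5),
            # cumulative encoding: class k -> ones in positions 0..k
            'ordinal_regression_output': np.array(
                [[1 if i <= k else 0 for i in range(5)] for k in y]),
        }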
Note that the ordinal regression encoding is done as follows for each of the classes (0, 1, 2, 3, 4); MultiLabelBinarizer in sklearn achieves this, as shown in the sketch below.
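A sketch of that encoding (the exact cumulative convention is an assumption; the learned linear head in the post-training stage absorbs any offset):

from sklearn.preprocessing import MultiLabelBinarizer

mlb = MultiLabelBinarizer(classes=[0, 1, 2, 3, 4])
# Represent class k as the set {0, 1, ..., k}
cumulative = [tuple(range(k + 1)) for k in range(5)]
print(mlb.fit_transform(cumulative))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]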
Stage 1 (Pre Training)
This involves making all layers trainable and using existing ImageNet weights as initializers for the ResNet50 encoder. The model is trained only up to the 3 output heads, and it is compiled with the loss functions below (20 epochs with the SGD optimizer and a cosine decay scheduler) —
model.compile(
    optimizer=optimizers.SGD(lr=LEARNING_RATE),
    loss={'regression_output': 'mean_absolute_error',
          'classification_output': 'categorical_crossentropy',
          'ordinal_regression_output': 'binary_crossentropy'},
    metrics=['accuracy'])
Stage 2 (Main Training)
2 things are changed in this stage.
- First, the loss functions are changed from cross-entropy to Focal Loss. You can read more about focal loss here — https://medium.com/adventures-with-deep-learning/focal-loss-demystified-c529277052de . In brief, focal loss does a better job of handling class imbalance, which is the case in our task.
Reference Code for focal loss — https://github.com/umbertogriffo/focal-loss-keras
model.compile(
    optimizer=optimizers.Adam(lr=LEARNING_RATE),
    loss={'regression_output': mean_squared_error,
          'classification_output': categorical_focal_loss(alpha=.25, gamma=2),
          'ordinal_regression_output': binary_focal_loss(alpha=.25, gamma=2)},
    metrics=['accuracy'])
2. Second, training in this 2nd stage is done in 2 sub-stages. The first sub-stage freezes all the encoder layers in the network; this is done to warm up the head weights (transfer learning on small datasets using ImageNet weight initializations). The 2nd sub-stage unfreezes and trains all layers.
# SUB-STAGE 1 (Stage 2) - freeze the encoder layers, train only the heads
for layer in model.layers:
    layer.trainable = False
for i in range(-14, 0):
    model.layers[i].trainable = True

# SUB-STAGE 2 (Stage 2) - unfreeze and train all layers
for layer in model.layers:
    layer.trainable = True
Graphs are shown below —
Stage 3 (Post Training)
This involves getting the outputs from the 3 heads (classification, regression, ordinal regression) and passing them to a single dense neuron (linear activation) trained to minimize mean squared error (50 epochs).
train_preds = model.predict_generator(complete_generator,
                                      steps=STEP_SIZE_COMPLETE, verbose=1)

train_output_regression = np.array(train_preds[0]).reshape(-1, 1)
train_output_classification = np.array(np.argmax(train_preds[1], axis=-1)).reshape(-1, 1)
train_output_ordinal_regression = np.array(np.sum(train_preds[2], axis=-1)).reshape(-1, 1)

X_train = np.hstack((train_output_regression,
                     train_output_classification,
                     train_output_ordinal_regression))

model_post = Sequential()
model_post.add(Dense(1, activation='linear', input_shape=(3,)))
model_post.compile(optimizer=optimizers.SGD(lr=LEARNING_RATE),
                   loss='mean_squared_error',
                   metrics=['mean_squared_error'])
Model Evaluation on Test Data
Since we have a regression output, we can round to the nearest integer to get the final class label.
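A short sketch of this decoding step (model_post comes from the post-training stage above; X_test, the stacked head outputs for the test set, is an assumed variable):

import numpy as np

final_preds = model_post.predict(X_test)                  # continuous scores
final_labels = np.clip(np.rint(final_preds), 0, 4).astype(int).ravel()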
The Final Quadratic Weighted kappa score that we obtain on Test Data is 0.704 (which indicates a Substantial Agreement between Model Predictions and Human Raters).
The normalized confusion matrix obtained on the test data is below (the matplotlib code was referenced from — link).
5. Other Transfer Learning Models
The common transfer learning method for small datasets (with little similarity to the ImageNet dataset) is to use existing ImageNet weights as initializers (freezing the first few layers) and then re-train the model.
We could use a similar implementation. A simple ResNet50 architecture gives good results when used this way (reference — link); a sketch is below.
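A sketch of that setup (the head sizes and learning rate are assumptions):

from keras import optimizers
from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, Dropout, GlobalAveragePooling2D
from keras.models import Sequential

resnet = ResNet50(include_top=False, weights='imagenet',
                  input_shape=(256, 256, 3))
resnet.trainable = False          # freeze the encoder initially

model = Sequential([
    resnet,
    GlobalAveragePooling2D(),
    Dropout(0.5),
    Dense(1024, activation='relu'),
    Dropout(0.5),
    Dense(5, activation='softmax'),
])
model.compile(optimizer=optimizers.Adam(lr=1e-4),
              loss='categorical_crossentropy',
              metrics=['accuracy'])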
We can train the above model for 2–5 epochs (only the last 5 layers, which are the layers after ResNet50, are trainable).
Then, we can make all layers Trainable and train the entire model.
for layer in model.layers:
layer.trainable = True
As we can see, within only 20 epochs we get a good accuracy score — close to 92% on the validation dataset.
This Model gives Quadratic weighted kappa of 0.83 on Test data (which is a good agreement score).
Similarly, we can implement other models from keras.applications, as listed on this page — link. Below are the results obtained for some of them (only the model architecture is changed; other parameters are kept the same).
6. Future Work
Experimentation with Optimizers
As we can see, the training loss fails to decrease in the main training stage. One way to deal with this is to change the optimizer from Adam to the Rectified Adam (RAdam) optimizer. (More on RAdam here — https://www.pyimagesearch.com/2019/10/07/is-rectified-adam-actually-better-than-adam/ .) We could also experiment with SGD (momentum) optimizers along with weight decay methods and more learning-rate schedulers to check for model performance improvements.
Experimentation with Regularization methods
Additionally, we could apply more regularization to help the model train better. One such technique is label smoothing (adding noise to the target labels), which allows the model to generalize better and prevents over-fitting; see the sketch below. We could also experiment with L2 regularization to improve the model's generalization ability.
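As a sketch, label smoothing is available as a built-in argument of tf.keras's cross-entropy loss (the 0.1 factor is an assumption):

import tensorflow as tf

# Soften one-hot targets: y_smooth = y * (1 - 0.1) + 0.1 / num_classes
loss_fn = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])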
Experimentation with Ensembling and K Fold Cross Validation
The research paper also mentions ensembling across various architectures (EfficientNetB4, EfficientNetB5, SE-ResNeXt50, etc.) and using stratified cross-validation (5-fold) to improve model performance and generalization ability.
7. Link to GitHub Code and LinkedIn profile
All the code is present in my GitHub repository — link. You can connect with me on my LinkedIn profile — link — if you would like to discuss further experimentation. You can also reach me at debayanmitra1993@gmail.com .
8. References used
Research paper — https://arxiv.org/pdf/2003.02261.pdf