Blindness detection (Diabetic retinopathy) using Deep learning on Eye retina images

Automating the process with convolutional neural networks (in Python) to speed up blindness detection in patients before it's too late

May 21 ·13min read

Table of contents

  1. Use of Deep learning to detect Blindness
  2. Evaluation metric (Quadratic weighted kappa)
  3. Image processing and analysis
  4. Implementation of an arXiv.org research Paper (Top 2% solution) using Multi Task Learning
  5. Other Transfer Learning Models
  6. Future work
  7. Link to GitHub code and LinkedIn profile
  8. References used


image source — http://images2.fanpop.com/image/photos/10300000/Cas-Blindness-castiel-10370846-640-352.jpg

1. Use of Deep learning to detect Blindness

This case study is based on the APTOS 2019 Blindness Detection challenge on Kaggle: https://www.kaggle.com/c/aptos2019-blindness-detection/overview

Millions of people suffer from diabetic retinopathy, the leading cause of blindness among working-age adults. Aravind Eye Hospital in India hopes to detect and prevent this disease among people living in rural areas, where medical screening is difficult to conduct. Currently, technicians travel to these rural areas to capture images and then rely on highly trained doctors to review the images and provide a diagnosis.

The goal here is to scale their efforts through technology: to gain the ability to automatically screen images for disease and provide information on how severe the condition may be. We do this by building a convolutional neural network model that looks at an image of a patient's eye and estimates the severity of blindness. Automating this process can save a great deal of time and makes it possible to screen for diabetic retinopathy at a large scale.

We are given 3200 eye images, each labelled with a severity score from [0, 1, 2, 3, 4] (0 = No DR, 1 = Mild, 2 = Moderate, 3 = Severe, 4 = Proliferative DR). This data is used to train the model, and predictions are made on the test data.

Blindness severity scale (5 classes)

Sample eye images in the dataset are below —


Eye images corresponding to each Blindness class severity (0–4)

2. Evaluation metric (Quadratic weighted kappa)

Quadratic weighted kappa measures the agreement between two ratings. It typically ranges from 0 (agreement no better than chance) to 1 (complete agreement between raters); if the raters agree less than would be expected by chance, it can go below 0. Here, it is computed between the scores assigned by the human rater and the predicted scores.

Intuition of Cohen’s kappa

To understand this metric, we first need the concept of Cohen's kappa (Wikipedia: link). Cohen's kappa corrects the agreement we observe in the confusion matrix for the agreement that would occur by chance. A simple example illustrates this:

Suppose we want to calculate Cohen's kappa from the agreement table below, which is essentially a confusion matrix. Out of 50 observations in total, raters 'A' and 'B' agree on 20 'yes' + 15 'no' = 35 observations, so the observed agreement is P(o) = 35/50 = 0.7.

Agreement table (rows: rater A, columns: rater B)
           B: Yes   B: No   Total
A: Yes       20       5       25
A: No        10      15       25
Total        30      20       50
source — wikipedia

We also need to find what proportion of the agreement is due to random chance rather than actual agreement between 'A' and 'B'. 'A' said yes in 25/50 = 0.5 of the cases and 'B' in 30/50 = 0.6 of the cases, so the probability that both say 'yes' at random is 0.5 * 0.6 = 0.3. Similarly, the probability that both say 'no' at random is 0.5 * 0.4 = 0.2. Thus, the probability of random agreement is 0.3 + 0.2 = 0.5. Let us call this P(e).

So Cohen's kappa is κ = (P(o) - P(e)) / (1 - P(e)) = (0.7 - 0.5) / (1 - 0.5) = 0.4.

source — wikipedia, cohens kappa formula
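To double-check the arithmetic, here is a short pure-Python sketch that computes unweighted Cohen's kappa from a square agreement table. The cell counts (20, 5, 10, 15) come from the totals quoted above, and the function name is only illustrative.

```python
def cohen_kappa_from_table(table):
    """Unweighted Cohen's kappa from a square agreement table.

    table[i][j] = number of items rater A put in class i and rater B in class j.
    """
    k = len(table)
    n = sum(sum(row) for row in table)
    # Observed agreement: share of items on the diagonal
    p_o = sum(table[i][i] for i in range(k)) / n
    # Marginal label frequencies of each rater
    a_freq = [sum(table[i]) / n for i in range(k)]
    b_freq = [sum(table[i][j] for i in range(k)) / n for j in range(k)]
    # Chance agreement: both raters independently pick the same class
    p_e = sum(a_freq[c] * b_freq[c] for c in range(k))
    return (p_o - p_e) / (1 - p_e)


# Rows: rater A (yes, no); columns: rater B (yes, no)
table = [[20, 5],
         [10, 15]]
print(round(cohen_kappa_from_table(table), 2))  # 0.4
```

If you have per-item labels rather than a table, sklearn.metrics.cohen_kappa_score(y1, y2) computes the same quantity.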

We can interpret the value of 0.4 using the table below. You can read more about interpreting this metric in this blog: https://towardsdatascience.com/inter-rater-agreement-kappas-69cd8b91ff75


source — https://towardsdatascience.com/inter-rater-agreement-kappas-69cd8b91ff75

A value of 0.4 lies on the boundary between fair and moderate agreement.

Intuition of Quadratic weights in ordinal classes — Quadratic weighted kappa

We can extend the same concept to multiple classes and introduce weights. The intuition behind the weights is that the classes are ordinal: agreement between classes '1' and '2' is better than between '1' and '3', because they are closer on the ordinal scale. In our case, the output (blindness severity) is ordinal, with 0 representing no blindness and 4 the highest severity. To account for this, quadratic weights are introduced: the closer two classes are, the higher their weight. This blog post explains the concept very well: https://medium.com/x8-the-ai-community/kappa-coefficient-for-dummies-84d98b6f13ee

The weights quoted here correspond to 5 ordinal classes (R1 to R5) with the quadratic agreement weight w = 1 - (i - j)^2 / (N - 1)^2, where N = 5. R1 and R4 are 3 ordinal classes apart, so their weight is 1 - 9/16 ≈ 0.44. The neighbouring classes R1 and R2 get 1 - 1/16 ≈ 0.94. These weights multiply the corresponding cell probabilities when computing both the observed agreement and the chance agreement.
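The pure-Python sketch below puts these pieces together for N ordered classes. It builds the observed agreement table, weights every cell with the quadratic agreement weight, and applies the same (observed - chance) / (1 - chance) formula as plain Cohen's kappa. The function names are illustrative. sklearn.metrics.cohen_kappa_score(y1, y2, weights='quadratic') computes the same value using the equivalent disagreement form.

```python
def quadratic_weight(i, j, num_classes):
    """Agreement weight: 1 on the diagonal, 0 for the two most distant classes."""
    return 1 - (i - j) ** 2 / (num_classes - 1) ** 2


def quadratic_weighted_kappa(rater_a, rater_b, num_classes=5):
    n = len(rater_a)
    # observed[i][j] counts items rated i by A and j by B
    observed = [[0] * num_classes for _ in range(num_classes)]
    for a, b in zip(rater_a, rater_b):
        observed[a][b] += 1
    a_freq = [sum(observed[i]) / n for i in range(num_classes)]
    b_freq = [sum(observed[i][j] for i in range(num_classes)) / n
              for j in range(num_classes)]
    p_o = p_e = 0.0
    for i in range(num_classes):
        for j in range(num_classes):
            w = quadratic_weight(i, j, num_classes)
            p_o += w * observed[i][j] / n       # weighted observed agreement
            p_e += w * a_freq[i] * b_freq[j]    # weighted chance agreement
    # Undefined (division by zero) if both raters only ever use one class
    return (p_o - p_e) / (1 - p_e)


print(round(quadratic_weight(0, 3, 5), 2))  # R1 vs R4 -> 0.44
print(round(quadratic_weight(0, 1, 5), 2))  # R1 vs R2 -> 0.94
```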

3. Image processing and analysis

Class distribution of output variable

Blindness detection (Diabetic retinopathy) using Deep learning on Eye retina images

Class imbalance — distribution

As we can see, the training dataset is imbalanced: most images belong to class '0', and the fewest belong to classes '1' and '3'.
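The class counts can be reproduced with a few lines of standard-library Python. This sketch assumes the Kaggle train.csv layout, which has an id_code column and a diagnosis column holding the 0 to 4 label. The function name is illustrative.

```python
import csv
from collections import Counter


def class_distribution(csv_path, label_column="diagnosis"):
    """Count how many images fall into each severity class and print shares."""
    with open(csv_path, newline="") as f:
        counts = Counter(int(row[label_column]) for row in csv.DictReader(f))
    total = sum(counts.values())
    for label in sorted(counts):
        share = counts[label] / total
        print(f"class {label}: {counts[label]:5d} images ({share:.1%})")
    return counts


# Usage (the path is an assumption about where the Kaggle data was unpacked):
# class_distribution("train.csv")
```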

Visualize Eye images

Code snippet for generating visualizations

These are eye retina images taken with fundus photography. The images contain artifacts: some are out of focus, underexposed or overexposed. Some also have low brightness or were taken in poor lighting conditions, which makes it difficult to tell the images apart.


Sample 15 eye images

Below is a sample retina image of an eye with diabetic retinopathy. Source links: kaggle kernel link. Original source: https://www.eyeops.com/


img source — https://www.kaggle.com/ratthachat/aptos-eye-preprocessing-in-diabetic-retinopathy , https://www.eyeops.com/

The aim is to spot lesions such as hemorrhages and exudates, which appear in the higher classes that have DR.

Image processing

To correct the images and make them clearer, so that the model can learn features more effectively, we apply some image processing techniques using the OpenCV library in Python (cv2).

We can apply a Gaussian blur to help bring out distinctive features in the images. In a Gaussian blur, the image is convolved with a Gaussian filter. This is a low-pass filter that removes high-frequency components.

Code snippet — Gaussian blur


Before/After gaussian blur

This well-written kernel (link) introduces the idea of circular cropping based on grayscale images. The same idea is implemented in the code section below:

Code snippet — Circular crop
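Here is a numpy-only sketch of the same two-step idea. It is not the kernel's exact code. First, it trims the near-black border using grayscale intensity (tol=7 is an assumed threshold). Then it zeroes out every pixel outside the largest circle centred in the trimmed image. The function names are illustrative.

```python
import numpy as np


def crop_dark_border(img, tol=7):
    """Drop rows/columns whose grayscale intensity never exceeds tol."""
    gray = img.mean(axis=2) if img.ndim == 3 else img
    mask = gray > tol
    if not mask.any():  # image too dark to crop: return it unchanged
        return img
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    return img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]


def circular_crop(img, tol=7):
    """Trim the dark border, then blank everything outside the inscribed circle."""
    img = crop_dark_border(img, tol)
    h, w = img.shape[:2]
    cy, cx, r = h / 2, w / 2, min(h, w) / 2
    yy, xx = np.ogrid[:h, :w]
    # Compare each pixel centre (y + 0.5, x + 0.5) with the circle
    inside = (yy + 0.5 - cy) ** 2 + (xx + 0.5 - cx) ** 2 <= r ** 2
    out = img.copy()
    out[~inside] = 0
    return out
```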


Before/after (Blur + crop)

Distinctive patterns in the images are now much easier to see. Here is the result of applying this processing to 15 sample images.


after image processing

Multi processing and resizing images to save in directory

After applying the above operations, we save the new images to a folder for later use. The original images are around 3 MB each, and the full image folder takes about 20 GB. We can reduce this by resizing the images. With multithreading, this takes under a minute: we use a ThreadPool with 6 worker threads (on an 8-core CPU) and resize every image to 512x512 (IMG_SIZE).

Code snippet — multi processing with ThreadPool

TSNE visualization

To check whether the images are separable by class (blindness severity), we can first use t-SNE to visualize them in two dimensions. We convert the RGB images to grayscale and then flatten each one into a vector, which serves as a feature representation of that image.

RGB (256x256x3) -> GRAYSCALE (256x256) -> FLATTEN (65536)

Code snippet — TSNE
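Here is a sketch of that pipeline with scikit-learn's TSNE. It is not the author's original snippet. It assumes the images are already loaded as an (n, height, width, 3) RGB array. The grayscale conversion uses the standard BT.601 luma weights (0.299, 0.587, 0.114), which are also the weights cv2.COLOR_RGB2GRAY uses. The function names are illustrative.

```python
import numpy as np
from sklearn.manifold import TSNE


def images_to_vectors(images):
    """(n, H, W, 3) RGB images -> (n, H*W) flattened grayscale vectors in [0, 1]."""
    images = np.asarray(images, dtype=np.float32)
    gray = images @ np.array([0.299, 0.587, 0.114], dtype=np.float32)  # (n, H, W)
    return gray.reshape(len(gray), -1) / 255.0


def tsne_embed(vectors, perplexity=40, seed=0):
    """Project vectors to 2-D; perplexity must be smaller than the sample count."""
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca",
                random_state=seed)
    return tsne.fit_transform(vectors)


# Usage (train_images and labels are assumed to exist):
# emb = tsne_embed(images_to_vectors(train_images))
# plt.scatter(emb[:, 0], emb[:, 1], c=labels, s=5)
```

With thousands of 65,536-dimensional vectors, reducing them with PCA first (for example to 50 components) usually makes t-SNE much faster.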

Perplexity is the hyperparameter that needs to be tuned to get good results. After trying several values, we use the t-SNE plot for perplexity = 40.


TSNE plot, perplexity = 40

As the plot shows, class '0' is somewhat separable from the other classes, whereas the distinction among the remaining classes is still vague.

Image augmentations

Image augmentation is one of the most widely used techniques for making a model robust. It creates additional training images from the dataset through rotations, flips, cropping, padding and so on, so that the model generalizes well to new data. Here, it is done with the Keras ImageDataGenerator class.
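The ImageDataGenerator settings are not listed here, so the sketch below swaps in a small library-free numpy version of a few similar transforms instead. It applies random horizontal and vertical flips, a random rotation by a multiple of 90 degrees, and a random contrast and brightness change. These transforms and their ranges are assumptions. Flips and rotations are commonly treated as label-preserving for fundus images. Note that a 90-degree rotation swaps height and width for non-square images.

```python
import numpy as np


def augment(image, rng):
    """Return a randomly flipped, rotated and contrast/brightness-jittered copy.

    image: (H, W, 3) uint8 array; rng: numpy Generator.
    """
    out = image.astype(np.float32)
    if rng.random() < 0.5:
        out = out[:, ::-1]                        # horizontal flip
    if rng.random() < 0.5:
        out = out[::-1, :]                        # vertical flip
    out = np.rot90(out, k=int(rng.integers(4)))   # rotate by 0/90/180/270 degrees
    contrast = rng.uniform(0.8, 1.2)
    brightness = rng.uniform(-20, 20)
    mean = out.mean()
    out = (out - mean) * contrast + mean + brightness
    return np.clip(out, 0, 255).astype(np.uint8)


rng = np.random.default_rng(42)
sample = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
augmented = [augment(sample, rng) for _ in range(4)]
```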

Sample images obtained after applying the augmentations are shown below:


Image augmentations (sample)

4. Implementation of an arXiv.org research Paper (Top 2% solution) using Multi Task Learning

One of the research papers found on arXiv.org (research paper link), a solution that ranked 54th out of 2943 (top 2%), describes a detailed multi-task learning approach to this problem. Below is an attempt to replicate the paper's approach in Python; no code for this paper was found on GitHub. I have summarized the paper in a Research Paper (summary doc).pdf file (github link). The key points and the code implementation are summarized below.

Datasets Used

The actual research paper uses multiple datasets from other sources like —

  1. 35,216 images from the Diabetic Retinopathy 2015 challenge: https://www.kaggle.com/c/diabetic-retinopathy-detection/overview/timeline
  2. Indian Diabetic Retinopathy Image Dataset (IDRiD) (Sahasrabuddhe and Meriaudeau, 2018): 413 images used
  3. MESSIDOR dataset (Google Brain, 2018)
  4. The full competition dataset, which consists of 18590 fundus photographs divided by the organizers of the Kaggle competition into 3662 training, 1928 validation and 13000 test images

However, since not all of these datasets were readily available, only the existing APTOS 2019 dataset was used for this task.

Image Preprocessing and Augmentations

Image preprocessing techniques such as resizing and cropping were used to bring out distinctive features in the eye images (as discussed in the sections above).

The image augmentations used were optical distortion, grid distortion, piecewise affine transform, horizontal flip, vertical flip, random rotation, random shift, random scale, a shift of RGB values, random brightness and contrast, additive Gaussian noise, blur, sharpening, embossing, random gamma, and cutout.

Here, we instead use a similar set of augmentations (reference used: link).

Model Architecture Used

The research paper uses a multi-task learning model that trains regression, classification and ordinal regression heads in parallel. The early layers of the network would learn similar features for all three tasks anyway, so a single shared model reduces training time compared with training three separate models. Any existing CNN architecture can serve as the encoder, such as ResNet50, EfficientNetB4 or EfficientNetB5, and these can also be ensembled.


source — https://arxiv.org/pdf/2003.02261.pdf

The code implementation for the above architecture is below (in keras) —
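Here is a sketch of this architecture in tf.keras. It is not the author's exact code. A shared ResNet50 encoder (include_top=False, global average pooling) feeds three heads, named to match the loss dictionary in the model.compile calls of this section. The head sizes are assumptions: one linear unit for regression, a 5-way softmax for classification, and 5 sigmoid units for the ordinal targets.

```python
import tensorflow as tf
from tensorflow.keras import layers

NUM_CLASSES = 5


def build_multitask_model(input_shape=(512, 512, 3), weights="imagenet"):
    """Shared ResNet50 encoder with regression, classification and ordinal heads."""
    encoder = tf.keras.applications.ResNet50(
        include_top=False, weights=weights, input_shape=input_shape, pooling="avg")
    features = encoder.output  # (batch, 2048) pooled feature vector

    regression = layers.Dense(1, activation="linear",
                              name="regression_output")(features)
    classification = layers.Dense(NUM_CLASSES, activation="softmax",
                                  name="classification_output")(features)
    ordinal = layers.Dense(NUM_CLASSES, activation="sigmoid",
                           name="ordinal_regression_output")(features)
    return tf.keras.Model(encoder.input, [regression, classification, ordinal])


# Usage: model = build_multitask_model()  # downloads ImageNet weights on first use
```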

Creating Multi Output — Custom Image Data Generator

Since the model has 3 outputs, we need a custom data generator instead of the standard ImageDataGenerator flow (reference: link). This generator yields 3 targets for each batch (regression, classification, ordinal regression).
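Here is a minimal sketch of such a wrapper. It assumes the underlying iterator yields batches of images together with integer severity labels, as an ImageDataGenerator.flow_from_dataframe iterator does with class_mode='raw'. Each label batch is turned into three targets, keyed by the output-layer names used when compiling. The one-hot and cumulative ordinal encodings are assumptions that match a 5-class softmax head and a 5-unit sigmoid head. The function name is illustrative.

```python
import numpy as np

NUM_CLASSES = 5


def multi_output_generator(base_iterator):
    """Yield (images, targets) with one target array per output head."""
    for images, labels in base_iterator:
        labels = np.asarray(labels).astype(int).ravel()
        one_hot = np.eye(NUM_CLASSES, dtype=np.float32)[labels]
        # Cumulative encoding: class k -> ones in positions 0..k, e.g. 2 -> [1,1,1,0,0]
        ordinal = (np.arange(NUM_CLASSES) <= labels[:, None]).astype(np.float32)
        yield images, {
            "regression_output": labels.astype(np.float32).reshape(-1, 1),
            "classification_output": one_hot,
            "ordinal_regression_output": ordinal,
        }
```

Keras iterators such as flow_from_dataframe loop indefinitely, so the wrapper does too. That is why steps_per_epoch has to be passed to fit.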

The ordinal regression targets use a special encoding for each of the classes (0, 1, 2, 3, 4), described in the linked Kaggle discussion. sklearn's MultiLabelBinarizer can produce this encoding.

source — https://www.kaggle.com/c/aptos2019-blindness-detection/discussion/98239
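Here is one common version of this ordinal encoding, which is an assumption since the exact table may differ. Class k is represented by the label set {0, ..., k}, and MultiLabelBinarizer turns that set into a 5-bit vector whose first k + 1 bits are set.

```python
from sklearn.preprocessing import MultiLabelBinarizer

classes = [0, 1, 2, 3, 4]
# Class k is represented by the set of all classes up to and including k
label_sets = [list(range(k + 1)) for k in classes]

mlb = MultiLabelBinarizer(classes=classes)
ordinal_targets = mlb.fit_transform(label_sets)
for k, row in zip(classes, ordinal_targets):
    print(k, row.tolist())
# 0 [1, 0, 0, 0, 0]
# 1 [1, 1, 0, 0, 0]
# 2 [1, 1, 1, 0, 0]
# 3 [1, 1, 1, 1, 0]
# 4 [1, 1, 1, 1, 1]
```

With this encoding, summing the predicted bits gives roughly k + 1. A later linear layer with a bias term can absorb that constant offset.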

Stage 1 (Pre Training)

In this stage, all layers are trainable and the ResNet50 encoder is initialized with ImageNet weights. The model is trained only up to the 3 output heads. It is compiled with the loss functions below and trained for 20 epochs with the SGD optimizer and a cosine decay scheduler:

from tensorflow.keras import optimizers

model.compile(
    optimizer=optimizers.SGD(learning_rate=LEARNING_RATE),
    loss={'regression_output': 'mean_absolute_error',
          'classification_output': 'categorical_crossentropy',
          'ordinal_regression_output': 'binary_crossentropy'},
    metrics=['accuracy'])
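The compile call above does not include the cosine decay schedule. One way to add it is sketched below, with an assumed total of 20 epochs and a minimum learning rate of 0. A plain function is passed to tf.keras.callbacks.LearningRateScheduler, which calls it at the start of every epoch with the epoch index and the current learning rate. The function name is illustrative.

```python
import math


def cosine_decay(epoch, lr=None, base_lr=1e-3, total_epochs=20, min_lr=0.0):
    """Learning rate for a 0-based epoch index, following half a cosine wave.

    The current lr argument is ignored; it is accepted so the function can be
    passed directly to tf.keras.callbacks.LearningRateScheduler.
    """
    progress = min(epoch, total_epochs) / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


print([round(cosine_decay(e), 6) for e in (0, 10, 20)])  # [0.001, 0.0005, 0.0]

# Usage with Keras (LEARNING_RATE as in the snippet above):
# from functools import partial
# from tensorflow.keras.callbacks import LearningRateScheduler
# scheduler = LearningRateScheduler(partial(cosine_decay, base_lr=LEARNING_RATE))
# model.fit(..., epochs=20, callbacks=[scheduler])
```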


Epoch (vs) Loss Graphs for Pre Training Stage

Stage 2 (Main Training)

Two things change in this stage.

  1. First, the loss functions are changed from cross-entropy to focal loss. You can read more about focal loss here: https://medium.com/adventures-with-deep-learning/focal-loss-demystified-c529277052de. In brief, focal loss down-weights well-classified examples so that training concentrates on hard, misclassified ones. This helps with class imbalance like the one in our task.
source — https://medium.com/adventures-with-deep-learning/focal-loss-demystified-c529277052de

Reference Code for focal loss — https://github.com/umbertogriffo/focal-loss-keras

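from tensorflow.keras import optimizers
from tensorflow.keras.losses import mean_squared_error
# categorical_focal_loss and binary_focal_loss come from the focal-loss-keras reference code

model.compile(
    optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
    loss={'regression_output': mean_squared_error,
          'classification_output': categorical_focal_loss(alpha=.25, gamma=2),
          'ordinal_regression_output': binary_focal_loss(alpha=.25, gamma=2)},
    metrics=['accuracy'])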

  2. Second, training in this stage is split into 2 sub-stages. The first sub-stage freezes the encoder layers and keeps only the last 14 layers of the network trainable. This warms up the weights (transfer learning on a small dataset with ImageNet weight initializations). The second sub-stage unfreezes and trains all layers.

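# SUB STAGE 1 (2nd stage): freeze the encoder, keep only the last 14 layers trainable
for layer in model.layers:
    layer.trainable = False
for i in range(-14, 0):
    model.layers[i].trainable = True
# Recompile the model after changing `trainable`, then train

# SUB STAGE 2 (2nd stage): unfreeze all layers
for layer in model.layers:
    layer.trainable = True
# Recompile again before continuing training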

Graphs are shown below —


Epoch (vs) Loss Graphs for Main Training Stage
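To see why the switch to focal loss helps with imbalance, the numpy sketch below compares plain cross-entropy with the focal-loss term FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t). Here p_t is the predicted probability of the true class. The sketch uses the alpha = 0.25 and gamma = 2 values from this stage's compile call. It only illustrates the formula; it is not the Keras implementation from the referenced repository.

```python
import numpy as np


def cross_entropy(p_true):
    """Cross-entropy given the probability assigned to the true class."""
    return -np.log(p_true)


def focal_loss(p_true, alpha=0.25, gamma=2.0):
    """Focal loss given the probability assigned to the true class."""
    return -alpha * (1 - p_true) ** gamma * np.log(p_true)


for p in (0.9, 0.5, 0.1):  # easy, uncertain and hard examples
    ce, fl = cross_entropy(p), focal_loss(p)
    print(f"p_t={p:.1f}  CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.4f}")
```

The ratio FL/CE equals alpha * (1 - p_t)^gamma. A confidently correct example (p_t = 0.9) is therefore scaled down to 0.0025 of its cross-entropy, while a hard example (p_t = 0.1) keeps about 0.2 of it.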

Stage 3 (Post Training)

This stage takes the outputs of the 3 heads (classification, regression, ordinal regression) and passes them to a single dense neuron with linear activation. That neuron is trained to minimize mean squared error for 50 epochs.

import numpy as np
from tensorflow.keras import optimizers
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

# complete_generator must not shuffle, so predictions stay aligned with the labels
train_preds = model.predict(complete_generator, steps=STEP_SIZE_COMPLETE, verbose=1)

# Reduce each head to a single number per image
train_output_regression = np.array(train_preds[0]).reshape(-1, 1)
train_output_classification = np.argmax(train_preds[1], axis=-1).reshape(-1, 1)
train_output_ordinal_regression = np.sum(train_preds[2], axis=-1).reshape(-1, 1)
X_train = np.hstack((train_output_regression,
                     train_output_classification,
                     train_output_ordinal_regression))

# One linear neuron combines the three head outputs
model_post = Sequential()
model_post.add(Dense(1, activation='linear', input_shape=(3,)))
model_post.compile(optimizer=optimizers.SGD(learning_rate=LEARNING_RATE),
                   loss='mean_squared_error', metrics=['mean_squared_error'])


Post Training — Epochs (vs) Loss

Model Evaluation on Test Data

Since the final output is a regression value, we round it to the nearest integer to get the final class label.
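Here is a small sketch of that step. Clipping to the valid 0 to 4 range before rounding is an added assumption, because a linear output can fall slightly outside it. Also note that np.rint rounds exact halves to the nearest even integer, so 2.5 becomes 2. The function name is illustrative.

```python
import numpy as np


def to_class_labels(regression_preds, num_classes=5):
    """Clip continuous predictions to [0, num_classes - 1] and round to integers."""
    preds = np.clip(np.ravel(regression_preds), 0, num_classes - 1)
    return np.rint(preds).astype(int)


labels = to_class_labels([-0.3, 0.6, 2.49, 3.7, 4.8])
print(labels.tolist())  # [0, 1, 2, 4, 4]

# Scoring (y_test and test_preds are assumed to exist):
# from sklearn.metrics import cohen_kappa_score
# qwk = cohen_kappa_score(y_test, to_class_labels(test_preds), weights="quadratic")
```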

The final quadratic weighted kappa on the test data is 0.704, which indicates substantial agreement between the model's predictions and the human raters.

The normalized confusion matrix on the test data is shown below (the matplotlib code is adapted from: link).


Test Data — Normalized Confusion Matrix

5. Other Transfer Learning Models

A common transfer learning approach for small datasets (especially ones with little similarity to ImageNet) is to initialize the model with existing ImageNet weights, freeze the first layers, and then retrain the model.

We can use a similar implementation here. A simple ResNet50 architecture gives good results when used this way (reference: link).

We can first train this model for 2 to 5 epochs with only the last 5 layers (the layers after ResNet50) trainable.

Then, we can make all layers trainable and train the entire model.

for layer in model.layers:
    layer.trainable = True
# Recompile the model so the change takes effect, then continue training


ResNet50 (Transfer Learning) — Epoch (vs) Loss

Within only 20 epochs, the model reaches close to 92% accuracy on the validation set.

This model gives a quadratic weighted kappa of 0.83 on the test data, which indicates good agreement.


Normalized Confusion Matrix (Test data)

Similar implementations work with the other keras.applications models listed on this page (link). Below are the results obtained with some of them (only the model architecture is changed; all other parameters are kept the same).

Model Performance — Summary

6. Future Work

Experimentation with Optimizers

The training loss fails to decrease much in the main training stage. One way to address this is to switch the optimizer from Adam to Rectified Adam (RAdam; more on it here: https://www.pyimagesearch.com/2019/10/07/is-rectified-adam-actually-better-than-adam/). We could also try SGD with momentum combined with weight decay, as well as other learning rate schedulers, to see whether model performance improves.

Experimentation with Regularization methods

We could also apply more regularization to help the model train better. One option is label smoothing, which adds noise to the target labels and can help the model generalize better and reduce over-fitting. We could also experiment with L2 regularization to improve the model's ability to generalize.
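For one-hot targets, label smoothing replaces each target y with (1 - epsilon) * y + epsilon / K, where K is the number of classes. Here is a numpy sketch with an assumed epsilon = 0.1. In Keras, the label_smoothing argument of tf.keras.losses.CategoricalCrossentropy applies the same transformation.

```python
import numpy as np


def smooth_labels(one_hot, epsilon=0.1):
    """Move epsilon of the probability mass from the true class to all classes."""
    one_hot = np.asarray(one_hot, dtype=np.float64)
    num_classes = one_hot.shape[-1]
    return one_hot * (1 - epsilon) + epsilon / num_classes


print(smooth_labels([[0, 0, 1, 0, 0]]).round(2).tolist())  # [[0.02, 0.02, 0.92, 0.02, 0.02]]
```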

Experimentation with Ensembling and K Fold Cross Validation

The research paper also mentions ensembling various architectures (EfficientNetB4, EfficientNetB5, SE-ResNeXt50, etc.) and using 5-fold stratified cross-validation to improve model performance and generalization.
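Here is a sketch of the 5-fold stratified split with scikit-learn, assuming the labels are available as an array. Each validation fold keeps roughly the same class proportions as the full dataset, which matters given the class imbalance seen earlier. One simple way to combine the folds is to train one model per fold and average their predictions. The function name is illustrative.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def stratified_folds(labels, n_splits=5, seed=42):
    """Return (train_idx, val_idx) pairs with class proportions preserved per fold."""
    labels = np.asarray(labels)
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # split() only needs the number of rows of X, so a placeholder array is enough
    return list(skf.split(np.zeros(len(labels)), labels))


labels = np.array([0] * 50 + [1] * 10 + [2] * 25 + [3] * 5 + [4] * 10)
for fold, (train_idx, val_idx) in enumerate(stratified_folds(labels)):
    print(fold, len(train_idx), len(val_idx), np.bincount(labels[val_idx], minlength=5))
```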

7. Link to GitHub Code and LinkedIn profile

All the code is in my GitHub repository: link. You can connect with me on LinkedIn (link) if you would like to discuss further experiments. You can also reach me at debayanmitra1993@gmail.com.

