Deep Learning in Healthcare — X-Ray Imaging (Part 4-The Class Imbalance problem)
This is part 4 of the application of deep learning to X-ray imaging. Here the focus will be on various ways to tackle the class imbalance problem.
As we saw in the previous part — Part 3 — ( https://towardsdatascience.com/deep-learning-in-healthcare-x-ray-imaging-part-3-analyzing-images-using-python-915a98fbf14c ), the chest x-ray dataset has an imbalance of images. This is the bar chart of the images per class that we had seen in the previous part.
In medical imaging datasets, this is a very common problem. The data is often collected from many different sources, and not all diseases are as prevalent as others, so more often than not the datasets are imbalanced.
So what is the problem if we train the neural network on an imbalanced dataset? The answer is that the network tends to learn more from the classes with more images than from the ones with fewer images. In this case, the model might predict more images to be 'Bacterial Pneumonia' even when the images belong to the other two classes, which is an undesirable outcome when dealing with medical images.
Also, it should be noted that, while dealing with medical images, the final accuracy (whether training accuracy or validation accuracy) of the model is not the right parameter to judge the model's performance on. Even if the model performs poorly on a particular class but well on the class with the most images, the accuracy will still be high. In reality, we want the model to perform well on all the classes. There are therefore other parameters, such as sensitivity (recall/true positive rate, TPR), specificity (true negative rate, TNR), precision (positive predictive value, PPV), and F-scores, which should be considered when analyzing the performance of a trained model. We will discuss these in detail in a later part, where we discuss the confusion matrix.
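These metrics will be covered properly in a later part, but as a quick preview (using toy label arrays rather than the pneumonia data), scikit-learn's classification_report prints precision, recall (sensitivity), and F1-score for each class from the true and predicted labels:
from sklearn.metrics import classification_report, confusion_matrix

# toy ground-truth and predicted labels for a 3-class problem
y_true = [0, 0, 1, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 1, 2, 0, 2]

print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred, target_names=['Normal', 'Bacteria', 'Virus']))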
It is also a must to maintain a separate set of images on which the model is neither trained nor validated, so as to check how the model performs on images it has never seen before. This too is a compulsory step in analyzing the performance of the model.
Various ways to tackle class imbalance:
There are various ways to tackle the class imbalance problem. The best method is to collect more images for the minority classes, but that is not possible in many situations. In such cases, these 3 methods are commonly beneficial:
a. Weighted loss
b. Downsampling (undersampling)
c. Oversampling
We will go through each of these methods in detail:
1. Updating the loss function — Weighted Loss
Suppose we are using the binary cross-entropy loss function. The loss function looks like this -
L(X, y) = -log P(Y=1 | X) if y = 1, and -log P(Y=0 | X) if y = 0
This measures the output of a classification model whose output is between zero and one. (This loss function only works if we are doing a binary classification problem. For multiple classes, we use Categorical Cross-Entropy loss or Sparse Categorical Cross-Entropy loss. We will discuss basic loss functions in a later part).
Example — suppose the label of an image is 1, and the neural network predicts that the probability of the label being 1 is 0.2.
Let's apply the loss function to compute the loss for this example. Notice that we are interested in the label 1. So, we are going to use the first part of the loss function L. The loss L is going to be -
L = -log(0.2) = 0.70 (the logarithms here and in the examples below are base 10)
So this is the loss the algorithm gets for this example.
For another image whose label is 0, suppose the algorithm predicts a probability of 0.7 that the label is 1. Here we use the second part of the loss function, but we cannot apply it directly: the model outputs the probability of the label being 1, so the probability of the label being 0 is 1 minus that value.
In this case, L = -log(1 - 0.7) = -log(0.3) = 0.52
Now let's look at multiple examples, with class imbalance.
In Figure 1, we see there are a total of 10 images, but 8 of those belong to class label 1, and only 2 belong to class label 0. Hence this is a classic class imbalance problem. Assuming all images are predicted with a probability of 0.5:
Loss L for label 1 = -log(0.5) = 0.3
Loss L for label 0 = -log(1 - 0.5) = -log(0.5) = 0.3
So, the total loss for label 1 = 0.3 x 8 = 2.4
whereas the total loss for label 0 = 0.3 x 2 = 0.6
So most of the contribution to the loss comes from the class with label 1. When updating the weights, the algorithm will therefore favor the label 1 images much more than the label 0 images. This does not produce a very good classifier, and this is the Class Imbalance Problem.
The solution to the class imbalance problem is to modify the loss function to weigh the 1 and 0 classes differently.
w1 is the weight we assign to label 1 examples, and w0 the weight for label 0 examples. The new loss function is:
L = w1 x -log P(Y=1 | X) if y = 1, and
L = w0 x -log P(Y=0 | X) if y = 0
We want to give more weight to the classes with fewer images than to the classes with more images. So in this case, we give class 1, which has 8 examples, a weight of 2/10 = 0.2, and class 0, which has 2 examples, a weight of 8/10 = 0.8.
Generally, the weights are calculated using the formula below,
w1 = number of images with label 0 / total number of images = 2/10
w0 = number of images with label 1 / total number of images = 8/10
Below is the updated table of loss, by using weighted loss.
So for the new calculations, we just multiply each loss by the respective class weight: for label 1, 0.3 x 0.2 = 0.06 per image, and for label 0, 0.3 x 0.8 = 0.24 per image. Now if we calculate the total loss,
The total loss for label 1 = 0.06 x 8 = 0.48
The total loss for label 0 = 0.24 x 2 = 0.48
Now both classes contribute the same total loss. So even though the classes have different numbers of images, the algorithm now treats them equally, and the classifier learns to classify the classes with very few images just as well.
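For the pneumonia data, this article uses oversampling rather than a weighted loss, but as a rough sketch of how the weighted-loss idea could be applied with Keras (the array and model names below are placeholders, not code from this series), the per-class weights can be computed with the formula above and passed to model.fit through its class_weight argument:
import numpy as np

# y_train: placeholder array of integer class labels (0, 1, 2, ...)
# generalizing the formula above: weight of a class = fraction of the
# dataset that does NOT belong to that class
counts = np.bincount(np.ravel(y_train))
total = counts.sum()
class_weight = {i: (total - count) / total for i, count in enumerate(counts)}
print(class_weight)

# a compiled Keras model (placeholder name 'model') would then be trained with:
# model.fit(X_train, y_train, epochs=10, class_weight=class_weight)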
2. Downsampling
Downsampling is the process of removing images from the class with the most images to make its size comparable to the classes with fewer images.
For example, in the pneumonia classification problem, we see that there are 2530 bacterial pneumonia images compared to 1341 normal and 1337 viral pneumonia images. So we can just remove around 1200 images from the bacterial pneumonia class so that all the classes have a similar number of images.
This is possible for datasets that have a lot of images belonging to each class, and removing a few images will not hurt the performance of the neural network.
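This article does not use downsampling for the pneumonia data, but a minimal sketch of random downsampling with NumPy could look like this (X and y are assumed to be the image and integer-label arrays built in the loading code shown later):
import numpy as np

rng = np.random.default_rng(32)  # fixed seed so the kept subset is reproducible

y_flat = np.ravel(y)
minority_size = np.bincount(y_flat).min()  # size of the smallest class

# keep a random subset of equal size from every class
keep_idx = np.concatenate([
    rng.choice(np.where(y_flat == label)[0], size=minority_size, replace=False)
    for label in np.unique(y_flat)
])

X_downsampled, y_downsampled = X[keep_idx], y[keep_idx]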
3. Oversampling
Oversampling is the process of adding more images to minority classes so as to make the number of images in minority classes similar to those in the majority classes.
This can be done by simply duplicating the images in the minority classes. However, directly copying the same image two or three times can cause the network to overfit, so to reduce overfitting we can use artificial data augmentation to create more images for the minority classes. (This too causes some overfitting, but it is a much better technique than directly copying the original images.)
This is the technique that we used in the pneumonia classification task, and the network worked quite well.
Next, we look at the Python code to generate the artificial dataset.
import numpy as np
import pandas as pd
import cv2 as cv
import matplotlib.pyplot as plt
import os
import random
from sklearn.model_selection import train_test_split
We have seen all the libraries before, except sklearn.
sklearn — Scikit-learn (also known as sklearn) is a machine learning library for Python. It contains implementations of many well-known machine learning algorithms for classification and regression, such as support vector machines and random forests, and it is also a very important library for machine learning data pre-processing.
image_size = 256
labels = ['1_NORMAL', '2_BACTERIA','3_VIRUS']
def create_training_data(paths):
    images = []
    for label in labels:
        dir = os.path.join(paths, label)
        class_num = labels.index(label)
        for image in os.listdir(dir):
            image_read = cv.imread(os.path.join(dir, image))
            # note: cv.IMREAD_GRAYSCALE is an imread flag, not a resize option,
            # so the images stay 3-channel BGR here
            image_resized = cv.resize(image_read, (image_size, image_size), cv.IMREAD_GRAYSCALE)
            images.append([image_resized, class_num])
    return np.array(images)

train = create_training_data('D:/Kaggle datasets/chest_xray_tf/train')

X = []
y = []
for feature, label in train:
    X.append(feature)
    y.append(label)

X = np.array(X)
y = np.array(y)
y = np.expand_dims(y, axis=1)
The above code reads the training dataset and loads the images into X and the labels into y. The details were already covered in Part 3 — ( https://towardsdatascience.com/deep-learning-in-healthcare-x-ray-imaging-part-3-analyzing-images-using-python-915a98fbf14c ).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,random_state = 32, stratify=y)
Since we have only train and validation data and no test data, we create the test data using train_test_split from sklearn, which splits the data into train and test images and labels. We assign 20% of the data to the test set, hence 'test_size = 0.2'. 'random_state' shuffles the data the first time and then keeps the split constant on every subsequent run, so the images are not reshuffled every time we run train_test_split. 'stratify' is important here because the data is imbalanced: it makes sure that each class is split in the same proportion between the train and test sets.
Important Note — Oversampling should be done on the training data only, not on the test data. If the test data contained artificially generated images, the classifier results would not be a proper reflection of how much the network has actually learned. So the better method is to first split the train and test data and then oversample only the training data.
# checking the number of images of each class
a = 0
b = 0
c = 0
for label in y_train:
    if label == 0:
        a += 1
    if label == 1:
        b += 1
    if label == 2:
        c += 1

print(f'Number of Normal images = {a}')
print(f'Number of Bacteria images = {b}')
print(f'Number of Virus images = {c}')
# plotting the data
xe = [i for i, _ in enumerate(labels)]
numbers = [a,b,c]
plt.bar(xe,numbers,color = 'green')
plt.xlabel("Labels")
plt.ylabel("No. of images")
plt.title("Images for each label")
plt.xticks(xe, labels)
plt.show()
output -
So now we see that the training set has 1226 normal images, 2184 bacterial pneumonia images, and 1154 viral pneumonia images.
# check the difference from the majority class
difference_normal = b - a
difference_virus = b - c

print(difference_normal)
print(difference_virus)
output —
958
1030
Solving the imbalance —
def rotate_images(image, scale=1.0, h=256, w=256):
    center = (h/2, w/2)
    angle = random.randint(-25, 25)
    M = cv.getRotationMatrix2D(center, angle, scale)
    rotated = cv.warpAffine(image, M, (h, w))
    return rotated

def flip(image):
    flipped = np.fliplr(image)
    return flipped

def translation(image):
    x = random.randint(-50, 50)
    y = random.randint(-50, 50)
    rows, cols, z = image.shape
    M = np.float32([[1, 0, x], [0, 1, y]])
    translate = cv.warpAffine(image, M, (cols, rows))
    return translate

def blur(image):
    x = random.randrange(1, 5, 2)
    blur = cv.GaussianBlur(image, (x, x), cv.BORDER_DEFAULT)
    return blur
We will be using 4 types of data augmentation, implemented with the OpenCV library (the horizontal flip uses NumPy): 1. rotation, from -25 to +25 degrees at random; 2. flipping the images horizontally; 3. translation, with random shifts along both the x and y axes; 4. Gaussian blurring with a randomly chosen kernel size.
For details on how to implement data augmentation using OpenCV please visit the following link — https://opencv.org
def apply_aug(image):
    number = random.randint(1, 4)
    if number == 1:
        image = rotate_images(image, scale=1.0, h=256, w=256)
    if number == 2:
        image = flip(image)
    if number == 3:
        image = translation(image)
    if number == 4:
        image = blur(image)
    return image
The apply_aug function above picks one of the four augmentations at random, so every generated image receives a completely random augmentation. Next, we define the function that performs the actual oversampling.
def oversample_images(difference_normal, difference_virus, X_train, y_train):
    normal_counter = 0
    virus_counter = 0
    new_normal = []
    new_virus = []
    label_normal = []
    label_virus = []

    for i, item in enumerate(X_train):
        if y_train[i] == 0 and normal_counter < difference_normal:
            image = apply_aug(item)
            normal_counter = normal_counter + 1
            label = 0
            new_normal.append(image)
            label_normal.append(label)

        if y_train[i] == 2 and virus_counter < difference_virus:
            image = apply_aug(item)
            virus_counter = virus_counter + 1
            label = 2
            new_virus.append(image)
            label_virus.append(label)

    new_normal = np.array(new_normal)
    label_normal = np.array(label_normal)
    new_virus = np.array(new_virus)
    label_virus = np.array(label_virus)

    return new_normal, label_normal, new_virus, label_virus
This function creates artificially augmented images for the normal and viral pneumonia classes until each minority class has gained enough images to make up its difference from the bacterial pneumonia class. It then returns the newly created normal and viral pneumonia images and labels.
n_images, n_labels, v_images, v_labels = oversample_images(difference_normal, difference_virus, X_train, y_train)

print(n_images.shape)
print(n_labels.shape)
print(v_images.shape)
print(v_labels.shape)
output —
We see that, as expected, 958 new normal images and 1030 new viral pneumonia images have been created.
Let's visualize a few of the artificial normal images,
# Extract 9 random images
print('Display Random Images')

# Adjust the size of your images
plt.figure(figsize=(20,10))

for i in range(9):
    num = random.randint(0, len(n_images)-1)
    plt.subplot(3, 3, i + 1)
    plt.imshow(n_images[num], cmap='gray')
    plt.axis('off')

# Adjust subplot parameters to give specified padding
plt.tight_layout()
output -
Next, let’s visualize a few of the artificial viral pneumonia images,
# Displays 9 generated viral images
# Extract 9 random images
print('Display Random Images')

# Adjust the size of your images
plt.figure(figsize=(20,10))

for i in range(9):
    num = random.randint(0, len(v_images)-1)
    plt.subplot(3, 3, i + 1)
    plt.imshow(v_images[num], cmap='gray')
    plt.axis('off')

# Adjust subplot parameters to give specified padding
plt.tight_layout()
output -
Each of those images generated above has some kind of augmentation — rotation, translation, flipping or blurring, all applied at random.
Next, we merge these artificial images and their labels with the original training dataset.
new_labels = np.append(n_labels, v_labels)
y_new_labels = np.expand_dims(new_labels, axis=1)
x_new_images = np.append(n_images, v_images, axis=0)
X_train1 = np.append(X_train, x_new_images, axis=0)
y_train1 = np.append(y_train, y_new_labels)

print(X_train1.shape)
print(y_train1.shape)
output —
Now, the training dataset has 6552 images.
bacteria_new = 0
virus_new = 0
normal_new = 0

for i in y_train1:
    if i == 0:
        normal_new = normal_new + 1
    elif i == 1:
        bacteria_new = bacteria_new + 1
    else:
        virus_new = virus_new + 1

print('Number of Normal images =', normal_new)
print('Number of Bacteria images =', bacteria_new)
print('Number of Virus images =', virus_new)
# plotting the data
xe = [i for i, _ in enumerate(labels)]
numbers = [normal_new, bacteria_new, virus_new]
plt.bar(xe,numbers,color = 'green')
plt.xlabel("Labels")
plt.ylabel("No. of images")
plt.title("Images for each label")
plt.xticks(xe, labels)
plt.show()
output —
So finally, we have a balanced training dataset: all three classes now have 2184 images.
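One small point worth adding (it is not part of the code above): since the augmented images were appended to the end of X_train1 and y_train1, it is a good idea to shuffle the two arrays together before training, for example with scikit-learn:
from sklearn.utils import shuffle

# shuffle images and labels in unison so the augmented images
# are not all grouped at the end of the training set
X_train1, y_train1 = shuffle(X_train1, y_train1, random_state=32)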
So this is how we solved the Class Imbalance Problem. Feel free to try other methods and compare them with the final results.
Now that the class imbalance problem is dealt with, in the next part we will look into image normalization and data augmentation using Keras and TensorFlow.