When what is not enough
True, sometimes it is vital to distinguish between different kinds of objects. Is that a car speeding towards me, in which case I had better jump out of the way? Or is it a huge Doberman (in which case I would probably do the same)? Often in real life, though, what is needed is not coarse-grained classification but fine-grained segmentation.
Zooming in on images, we are not looking for a single label. Instead, we want to classify every pixel according to some criterion:
- In medicine, we may want to distinguish between different cell types, or identify tumors.
- In various earth sciences, satellite data are used to segment the earth’s surface.
- To enable custom backgrounds, video conferencing software has to be able to tell the foreground from the background.
Image segmentation is a form of supervised learning: some kind of ground truth is needed. Here, it comes in the form of a mask: an image with the same spatial resolution as the input data that designates the true class for every pixel. Accordingly, the classification loss is calculated pixel-wise; the losses are then summed up to yield an aggregate to be used in optimization.
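To make the pixel-wise loss concrete, here is a minimal base R sketch on a made-up 3 x 3 example. The matrices mask and pred are invented for illustration, and the loss is plain binary cross-entropy evaluated per pixel; it is not taken from the training code further down.

```r
# Toy ground-truth mask: 1 marks a lesion pixel, 0 marks background (made-up values)
mask <- matrix(c(0, 0, 0,
                 0, 1, 1,
                 0, 1, 0), nrow = 3, byrow = TRUE)

# Toy predicted probabilities for the same nine pixels (made-up values)
pred <- matrix(c(0.1, 0.2, 0.1,
                 0.1, 0.9, 0.6,
                 0.2, 0.7, 0.3), nrow = 3, byrow = TRUE)

# Binary cross-entropy, evaluated separately for every pixel
pixel_loss <- -(mask * log(pred) + (1 - mask) * log(1 - pred))

# Aggregate the per-pixel losses into a single number for the optimizer
total_loss <- sum(pixel_loss)
mean_loss <- mean(pixel_loss)

dim(pixel_loss)  # one loss value per pixel, same shape as the mask
```

Whether the aggregate is a sum or a mean only rescales the gradient; both variants are shown here.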
The “canonical” architecture for image segmentation is U-Net (around since 2015).
U-Net
Here is the prototypical U-Net, as depicted in the original Ronneberger et al. paper (Ronneberger, Fischer, and Brox 2015).
Numerous variants of this architecture exist. You could use different layer sizes, activations, ways to achieve downsizing and upsizing, and more. However, there is one defining characteristic: the U shape, stabilized by the “bridges” crossing over horizontally at all levels.
In a nutshell, the left-hand side of the U resembles the convolutional architectures used in image classification. It successively reduces spatial resolution. At the same time, another dimension, the channels dimension, is used to build up a hierarchy of features, ranging from very basic to very specialized.
Unlike in classification, however, the output should have the same spatial resolution as the input. Thus, we need to upsize again; this is taken care of by the right-hand side of the U. But how are we going to arrive at a good per-pixel classification, now that so much spatial information has been lost?
This is what the “bridges” are for. At each level, the input to an upsampling layer is a concatenation of the previous layer’s output, which went through the whole compression/decompression routine, and some preserved intermediate representation from the downsizing phase. In this way, a U-Net architecture combines attention to detail with feature extraction.
Brain image segmentation
With U-Net, domain applicability is as broad as the architecture is flexible. Here, we want to detect abnormalities in brain scans. The dataset, used in Buda, Saha, and Mazurowski (2019), contains MRI images together with manually created FLAIR abnormality segmentation masks. It is available on Kaggle.
Nicely, the paper is accompanied by a GitHub repository. Below, we closely follow (though do not exactly replicate) the authors’ preprocessing and data augmentation code.
As is often the case in medical imaging, there is notable class imbalance in the data. For every patient, sections have been taken at multiple positions. (The number of sections per patient varies.) Most sections do not exhibit any lesions; the corresponding masks are black everywhere.
Here are three examples where the masks do indicate abnormalities:
Let’s see if we can build a U-Net that generates such masks for us.
Data
Before you start typing, here is a Colaboratory notebook to conveniently follow along.
We use pins to obtain the data. If you have never used that package before, please refer to this introduction.
The dataset is not that big (it includes scans from 110 different patients), so we will have to make do with just a training and a validation set. (Don’t do this in real life, as you will inevitably end up fine-tuning on the latter.)
# packages used by the code in this post
library(torch)
library(torchvision)
library(tidyverse)
library(zeallot)
library(magick)
library(cowplot)
train_dir <- "data/mri_train"
valid_dir <- "data/mri_valid"
if (dir.exists(train_dir)) unlink(train_dir, recursive = TRUE, force = TRUE)
if (dir.exists(valid_dir)) unlink(valid_dir, recursive = TRUE, force = TRUE)
# `files` is assumed to hold the path to the zip archive obtained via pins
zip::unzip(files, exdir = "data")
file.rename("data/kaggle_3m", train_dir)
# this is a duplicate, again containing kaggle_3m (evidently a packaging error on Kaggle)
# we just remove it
unlink("data/lgg-mri-segmentation", recursive = TRUE)
dir.create(valid_dir)
Of the 110 patients, we keep 30 for validation. After some more file manipulation, we are set up with a nice hierarchical structure, with train_dir and valid_dir each holding their per-patient subdirectories.
# list the per-patient directories first, then draw 30 of them for validation
patients <- list.dirs(train_dir, recursive = FALSE)
valid_indices <- sample(1:length(patients), 30)
for (i in valid_indices) {
  dir.create(file.path(valid_dir, basename(patients[i])))
  for (f in list.files(patients[i])) {
    file.rename(file.path(train_dir, basename(patients[i]), f),
                file.path(valid_dir, basename(patients[i]), f))
  }
  unlink(file.path(train_dir, basename(patients[i])), recursive = TRUE)
}
We now need a dataset that knows what to do with these files.
Dataset
Like every torch dataset, this one has initialize() and .getitem() methods. initialize() creates an inventory of scan and mask file names, to be used by .getitem() when it actually reads those files. In contrast to what we have seen in previous posts, though, .getitem() does not simply return input-target pairs in order. Instead, whenever the parameter random_sampling is true, it performs weighted sampling, preferring items with sizable lesions. This option is used for the training set, to counter the class imbalance mentioned above.
The other way in which the training and validation sets differ is the use of data augmentation. Training images and masks may be flipped, resized, and rotated; probabilities and amounts are configurable.
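One detail matters for all of these augmentations: an image and its mask have to undergo exactly the same random transformation, or the labels no longer line up with the pixels. The following base R sketch illustrates this for a horizontal flip on plain matrices. The helper flip_pair() and the toy data are made up for illustration; the dataset itself works on tensors.

```r
# Flip an image and its mask together; a single random draw decides for both.
# Mirrors the role of flip_param in the dataset, but on plain matrices.
flip_pair <- function(img, mask, flip_param) {
  if (runif(1) > flip_param) {
    img <- img[, ncol(img):1, drop = FALSE]
    mask <- mask[, ncol(mask):1, drop = FALSE]
  }
  list(img = img, mask = mask)
}

# Made-up 2 x 3 "image" and mask; the lesion pixel sits in the last column
img <- matrix(1:6, nrow = 2)
mask <- matrix(c(0, 0, 0, 0, 0, 1), nrow = 2)

set.seed(42)
flipped <- flip_pair(img, mask, flip_param = 0.5)

# Whatever the draw, the lesion pixel still lines up with the same image value
img[mask == 1]
flipped$img[flipped$mask == 1]
```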
An instance of brainseg_dataset encapsulates all this functionality:
brainseg_dataset <- dataset(
  name = "brainseg_dataset",

  initialize = function(img_dir,
                        augmentation_params = NULL,
                        random_sampling = FALSE) {
    # inventory of scans and masks; mask files have "mask" in their name
    self$images <- tibble(
      img = grep(
        list.files(
          img_dir,
          full.names = TRUE,
          pattern = "tif",
          recursive = TRUE
        ),
        pattern = 'mask',
        invert = TRUE,
        value = TRUE
      ),
      mask = grep(
        list.files(
          img_dir,
          full.names = TRUE,
          pattern = "tif",
          recursive = TRUE
        ),
        pattern = 'mask',
        value = TRUE
      )
    )
    self$slice_weights <- self$calc_slice_weights(self$images$mask)
    self$augmentation_params <- augmentation_params
    self$random_sampling <- random_sampling
  },

  .getitem = function(i) {
    # with random sampling, ignore i and draw an index weighted by lesion size
    index <-
      if (self$random_sampling == TRUE)
        sample(1:self$.length(), 1, prob = self$slice_weights)
      else
        i

    img <- self$images$img[index] %>%
      image_read() %>%
      transform_to_tensor()
    mask <- self$images$mask[index] %>%
      image_read() %>%
      transform_to_tensor() %>%
      transform_rgb_to_grayscale() %>%
      torch_unsqueeze(1)

    img <- self$min_max_scale(img)

    # augmentation_params holds (scale, rotation, flip) settings, in that order
    if (!is.null(self$augmentation_params)) {
      scale_param <- self$augmentation_params[1]
      c(img, mask) %<-% self$resize(img, mask, scale_param)

      rot_param <- self$augmentation_params[2]
      c(img, mask) %<-% self$rotate(img, mask, rot_param)

      flip_param <- self$augmentation_params[3]
      c(img, mask) %<-% self$flip(img, mask, flip_param)
    }
    list(img = img, mask = mask)
  },

  .length = function() {
    nrow(self$images)
  },

  # sampling weight per slice, growing with the lesion area in its mask
  calc_slice_weights = function(masks) {
    weights <- map_dbl(masks, function(m) {
      img <-
        as.integer(magick::image_data(image_read(m), channels = "gray"))
      sum(img / 255)
    })

    sum_weights <- sum(weights)
    num_weights <- length(weights)

    weights <- weights %>% map_dbl(function(w) {
      w <- (w + sum_weights * 0.1 / num_weights) / (sum_weights * 1.1)
    })
    weights
  },

  min_max_scale = function(x) {
    min = x$min()$item()
    max = x$max()$item()
    x$clamp_(min = min, max = max)
    x$add_(-min)$div_(max - min + 1e-5)
    x
  },

  # random rescaling, followed by a crop or pad back to the original size
  resize = function(img, mask, scale_param) {
    img_size <- dim(img)[2]
    rnd_scale <- runif(1, 1 - scale_param, 1 + scale_param)
    img <- transform_resize(img, size = rnd_scale * img_size)
    mask <- transform_resize(mask, size = rnd_scale * img_size)
    diff <- dim(img)[2] - img_size
    if (diff > 0) {
      top <- ceiling(diff / 2)
      left <- ceiling(diff / 2)
      img <- transform_crop(img, top, left, img_size, img_size)
      mask <- transform_crop(mask, top, left, img_size, img_size)
    } else {
      img <- transform_pad(img,
                           padding = -c(
                             ceiling(diff / 2),
                             floor(diff / 2),
                             ceiling(diff / 2),
                             floor(diff / 2)
                           ))
      mask <- transform_pad(mask,
                            padding = -c(
                              ceiling(diff / 2),
                              floor(diff / 2),
                              ceiling(diff / 2),
                              floor(diff / 2)
                            ))
    }
    list(img, mask)
  },

  rotate = function(img, mask, rot_param) {
    rnd_rot <- runif(1, 1 - rot_param, 1 + rot_param)
    img <- transform_rotate(img, angle = rnd_rot)
    mask <- transform_rotate(mask, angle = rnd_rot)
    list(img, mask)
  },

  flip = function(img, mask, flip_param) {
    rnd_flip <- runif(1)
    if (rnd_flip > flip_param) {
      img <- transform_hflip(img)
      mask <- transform_hflip(mask)
    }
    list(img, mask)
  }
)
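To see what the weighting in calc_slice_weights() amounts to, here is a small base R sketch that applies the same formula to made-up lesion areas. The five area values are hypothetical; only the formula is taken from the dataset code above.

```r
# Hypothetical lesion areas (mask pixels) for five slices;
# three of the slices are lesion-free
lesion_area <- c(0, 0, 250, 0, 750)

sum_weights <- sum(lesion_area)
num_weights <- length(lesion_area)

# Same formula as in calc_slice_weights(): every slice gets a small floor,
# so lesion-free slices can still be drawn, just less often
weights <- (lesion_area + sum_weights * 0.1 / num_weights) / (sum_weights * 1.1)

weights
sum(weights)  # the weights form a probability distribution

# Draw slice indices the way .getitem() does when random_sampling is TRUE
set.seed(777)
draws <- sample(seq_along(weights), 1000, replace = TRUE, prob = weights)
table(draws)
```

With these numbers, the slice with the largest lesion is drawn most often, while the lesion-free slices share the small remainder.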
After instantiation, we can see that there are 2977 training pairs and 952 validation pairs respectively.
train_ds <- brainseg_dataset(
  train_dir,
  augmentation_params = c(0.05, 15, 0.5),
  random_sampling = TRUE
)
length(train_ds)
# 2977
valid_ds <- brainseg_dataset(
  valid_dir,
  augmentation_params = NULL,
  random_sampling = FALSE
)
length(valid_ds)
# 952
As a correctness check, let’s plot an image and its associated mask:
par(mfrow = c(1, 2), mar = c(0, 1, 0, 1))
img_and_mask <- valid_ds[27]
img <- img_and_mask[[1]]
mask <- img_and_mask[[2]]
img$permute(c(2, 3, 1)) %>% as.array() %>% as.raster() %>% plot()
mask$squeeze() %>% as.array() %>% as.raster() %>% plot()
With torch, it is straightforward to inspect what happens when you change augmentation-related parameters. We just pick a pair from the validation set, which has not had any augmentation applied yet, and call valid_ds$<augmentation_func()> directly. Just for fun, let’s use more “extreme” parameters here than we do in actual training. (Actual training uses the settings from Mateusz’s GitHub repository, which we assume have been carefully chosen for optimal performance.)
img_and_mask <- valid_ds[77]
img <- img_and_mask[[1]]
mask <- img_and_mask[[2]]
imgs <- map(1:24, function(i) {
  # scale factor; train_ds really uses 0.05
  c(img, mask) %<-% valid_ds$resize(img, mask, 0.2)
  c(img, mask) %<-% valid_ds$flip(img, mask, 0.5)
  # rotation angle; train_ds really uses 15
  c(img, mask) %<-% valid_ds$rotate(img, mask, 90)
  img %>%
    transform_rgb_to_grayscale() %>%
    as.array() %>%
    as_tibble() %>%
    rowid_to_column(var = "Y") %>%
    gather(key = "X", value = "value", -Y) %>%
    mutate(X = as.numeric(gsub("V", "", X))) %>%
    ggplot(aes(X, Y, fill = value)) +
    geom_raster() +
    theme_void() +
    theme(legend.position = "none") +
    theme(aspect.ratio = 1)
})
plot_grid(plotlist = imgs, nrow = 4)
Now we still need the data loaders, and then nothing keeps us from proceeding to the next big task: building the model.
batch_size <- 4
train_dl <- dataloader(train_ds, batch_size)
valid_dl <- dataloader(valid_ds, batch_size)
Model
Our model nicely illustrates the kind of modular code that comes “naturally” with torch. We approach things top-down, starting with the U-Net container itself.
unet takes care of the global composition: how far “down” do we go, shrinking the image while increasing the number of filters, and how do we then go “up” again?
Importantly, it also serves as the system’s memory. In forward(), it keeps track of the layer outputs seen going “down”, to be added back in when going “up”.
unet <- nn_module(
  "unet",

  initialize = function(channels_in = 3,
                        n_classes = 1,
                        depth = 5,
                        n_filters = 6) {
    # going "down": the number of filters doubles at every level
    self$down_path <- nn_module_list()
    prev_channels <- channels_in
    for (i in 1:depth) {
      self$down_path$append(down_block(prev_channels, 2 ^ (n_filters + i - 1)))
      prev_channels <- 2 ^ (n_filters + i - 1)
    }

    # going "up": the number of filters halves at every level
    self$up_path <- nn_module_list()
    for (i in ((depth - 1):1)) {
      self$up_path$append(up_block(prev_channels, 2 ^ (n_filters + i - 1)))
      prev_channels <- 2 ^ (n_filters + i - 1)
    }

    self$last = nn_conv2d(prev_channels, n_classes, kernel_size = 1)
  },

  forward = function(x) {
    # outputs seen on the way "down", kept for the bridges
    blocks <- list()
    for (i in 1:length(self$down_path)) {
      x <- self$down_path[[i]](x)
      if (i != length(self$down_path)) {
        blocks <- c(blocks, x)
        x <- nnf_max_pool2d(x, 2)
      }
    }

    # on the way "up", each level receives the matching stored output
    for (i in 1:length(self$up_path)) {
      x <- self$up_path[[i]](x, blocks[[length(blocks) - i + 1]]$to(device = device))
    }
    torch_sigmoid(self$last(x))
  }
)
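To make the global composition tangible, the following base R sketch traces channel counts and spatial resolution through the network for the default settings (depth = 5, n_filters = 6). The input size of 256 x 256 pixels is an assumption made for illustration; the arithmetic simply follows the loops in initialize() and forward() above.

```r
depth <- 5
n_filters <- 6
input_size <- 256  # assumed spatial size, for illustration only

# Going "down": channels double at each level; max pooling halves the
# resolution after every block except the last
down_channels <- 2^(n_filters + seq_len(depth) - 1)
down_size <- input_size / 2^(seq_len(depth) - 1)

# Going "up": each transposed convolution doubles the resolution
# and halves the number of channels
up_levels <- (depth - 1):1
up_channels <- 2^(n_filters + up_levels - 1)
up_size <- input_size / 2^(up_levels - 1)

# After upsampling, the stored output from the same level is concatenated,
# which doubles the channel count fed into the following conv block
concat_channels <- 2 * up_channels

data.frame(
  path = c(rep("down", depth), rep("up", depth - 1)),
  channels_out = c(down_channels, up_channels),
  spatial_size = c(down_size, up_size)
)
```

Under these assumptions, the bottleneck holds 1024 channels at a resolution of 16 x 16, and the last up level returns to 64 channels at full resolution before the final 1 x 1 convolution maps to a single output channel.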
unet delegates to two containers just below it in the hierarchy: down_block and up_block. While down_block is “just” there for aesthetic reasons (it immediately delegates to its own workhorse, conv_block), in up_block we see the U-Net “bridges” in action.
down_block <- nn_module(
  "down_block",

  initialize = function(in_size, out_size) {
    self$conv_block <- conv_block(in_size, out_size)
  },

  forward = function(x) {
    self$conv_block(x)
  }
)
up_block <- nn_module(
  "up_block",

  initialize = function(in_size, out_size) {
    self$up = nn_conv_transpose2d(in_size,
                                  out_size,
                                  kernel_size = 2,
                                  stride = 2)
    self$conv_block = conv_block(in_size, out_size)
  },

  forward = function(x, bridge) {
    up <- self$up(x)
    torch_cat(list(up, bridge), 2) %>%
      self$conv_block()
  }
)
Finally, a conv_block is a sequential structure containing convolutional, ReLU, and dropout layers.
conv_block <- nn_module(
  "conv_block",

  initialize = function(in_size, out_size) {
    self$conv_block <- nn_sequential(
      nn_conv2d(in_size, out_size, kernel_size = 3, padding = 1),
      nn_relu(),
      nn_dropout(0.6),
      nn_conv2d(out_size, out_size, kernel_size = 3, padding = 1),
      nn_relu()
    )
  },

  forward = function(x) {
    self$conv_block(x)
  }
)
Now instantiate the model and, if possible, move it to the GPU:
device <- torch_device(if(cuda_is_available()) "cuda" else "cpu")
model <- unet(depth = 5)$to(device = device)
Optimization
We train our model with a combination of cross-entropy and dice loss.
The latter, though not shipped with torch, may be implemented manually:
calc_dice_loss <- function(y_pred, y_true) {
  smooth <- 1
  y_pred <- y_pred$view(-1)
  y_true <- y_true$view(-1)
  intersection <- (y_pred * y_true)$sum()
  1 - ((2 * intersection + smooth) / (y_pred$sum() + y_true$sum() + smooth))
}
dice_weight <- 0.3
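To get a feeling for how the dice loss behaves on imbalanced masks, here is the same formula in base R, applied to a made-up 10 x 10 mask. The function dice_loss() and the toy data are illustrative stand-ins, not part of the training code.

```r
# Base R version of the dice loss above, for plain numeric vectors or matrices
dice_loss <- function(y_pred, y_true, smooth = 1) {
  intersection <- sum(y_pred * y_true)
  1 - (2 * intersection + smooth) / (sum(y_pred) + sum(y_true) + smooth)
}

# Made-up 10 x 10 mask with a small lesion covering 4 pixels
y_true <- matrix(0, nrow = 10, ncol = 10)
y_true[4:5, 4:5] <- 1

perfect <- y_true             # prediction identical to the mask
all_background <- y_true * 0  # prediction that misses the lesion entirely

dice_loss(perfect, y_true)         # 0
dice_loss(all_background, y_true)  # 0.8

# Pixel accuracy of the all-background prediction is still high,
# which is why a lesion-focused loss term is useful
mean(all_background == y_true)     # 0.96
```

A prediction that ignores the lesion altogether is correct on 96% of the pixels, yet it is penalized heavily by the dice term.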
Optimization uses stochastic gradient descent (SGD), together with the one-cycle learning rate scheduler introduced in the context of image classification with torch.
optimizer <- optim_sgd(model$parameters, lr = 0.1, momentum = 0.9)
num_epochs <- 20
scheduler <- lr_one_cycle(
  optimizer,
  max_lr = 0.1,
  steps_per_epoch = length(train_dl),
  epochs = num_epochs
)
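As a rough picture of what a one-cycle schedule does, the base R sketch below computes a learning rate for every training step: a warm-up towards max_lr, followed by a long decay. It is a simplified illustration, not the implementation behind lr_one_cycle(). The warm-up fraction, the start and end divisors, and the cosine shape are assumptions chosen for illustration, so please check the lr_one_cycle() documentation for the actual defaults.

```r
max_lr <- 0.1
num_epochs <- 20
steps_per_epoch <- ceiling(2977 / 4)  # training pairs / batch size
total_steps <- steps_per_epoch * num_epochs

# Assumed shape parameters (hypothetical, for illustration only)
pct_start <- 0.3             # fraction of steps spent warming up
initial_lr <- max_lr / 25    # learning rate at the first step
final_lr <- initial_lr / 1e4 # learning rate at the last step
warmup_steps <- round(pct_start * total_steps)

# Cosine interpolation from `start` to `end` as `pct` goes from 0 to 1
cos_anneal <- function(start, end, pct) {
  end + (start - end) / 2 * (1 + cos(pi * pct))
}

one_cycle_lr <- function(step) {
  if (step <= warmup_steps) {
    cos_anneal(initial_lr, max_lr, step / warmup_steps)
  } else {
    cos_anneal(max_lr, final_lr,
               (step - warmup_steps) / (total_steps - warmup_steps))
  }
}

lrs <- vapply(0:total_steps, one_cycle_lr, numeric(1))
range(lrs)
```

Because the scheduler is stepped once per batch in the training loop, the number of scheduler steps equals steps_per_epoch times the number of epochs.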
Training
The training loop then follows the usual scheme. One thing to note: every epoch, we save the model (using torch_save()), so that we can later pick the best one, should performance have degraded afterwards.
train_batch <- function(b) {
  optimizer$zero_grad()
  output <- model(b[[1]]$to(device = device))
  target <- b[[2]]$to(device = device)

  # weighted combination of cross-entropy and dice loss
  bce_loss <- nnf_binary_cross_entropy(output, target)
  dice_loss <- calc_dice_loss(output, target)
  loss <- dice_weight * dice_loss + (1 - dice_weight) * bce_loss

  loss$backward()
  optimizer$step()
  # the one-cycle scheduler is stepped once per batch
  scheduler$step()

  list(bce_loss$item(), dice_loss$item(), loss$item())
}
valid_batch <- function(b) {
  output <- model(b[[1]]$to(device = device))
  target <- b[[2]]$to(device = device)

  bce_loss <- nnf_binary_cross_entropy(output, target)
  dice_loss <- calc_dice_loss(output, target)
  loss <- dice_weight * dice_loss + (1 - dice_weight) * bce_loss

  list(bce_loss$item(), dice_loss$item(), loss$item())
}
for (epoch in 1:num_epochs) {
  model$train()
  train_bce <- c()
  train_dice <- c()
  train_loss <- c()

  coro::loop(for (b in train_dl) {
    c(bce_loss, dice_loss, loss) %<-% train_batch(b)
    train_bce <- c(train_bce, bce_loss)
    train_dice <- c(train_dice, dice_loss)
    train_loss <- c(train_loss, loss)
  })

  # keep a copy of the model after every epoch
  torch_save(model, paste0("model_", epoch, ".pt"))

  cat(sprintf("\nEpoch %d, training: loss:%3f, bce: %3f, dice: %3f\n",
              epoch, mean(train_loss), mean(train_bce), mean(train_dice)))

  model$eval()
  valid_bce <- c()
  valid_dice <- c()
  valid_loss <- c()

  i <- 0
  coro::loop(for (b in valid_dl) {
    i <<- i + 1
    c(bce_loss, dice_loss, loss) %<-% valid_batch(b)
    valid_bce <- c(valid_bce, bce_loss)
    valid_dice <- c(valid_dice, dice_loss)
    valid_loss <- c(valid_loss, loss)
  })

  cat(sprintf("\nEpoch %d, validation: loss:%3f, bce: %3f, dice: %3f\n",
              epoch, mean(valid_loss), mean(valid_bce), mean(valid_dice)))
}
Epoch 1, training: loss:0.304232, bce: 0.148578, dice: 0.667423
Epoch 1, validation: loss:0.333961, bce: 0.127171, dice: 0.816471
Epoch 2, training: loss:0.194665, bce: 0.101973, dice: 0.410945
Epoch 2, validation: loss:0.341121, bce: 0.117465, dice: 0.862983
(...)
Epoch 19, training: loss:0.073863, bce: 0.038559, dice: 0.156236
Epoch 19, validation: loss:0.302878, bce: 0.109721, dice: 0.753577
Epoch 20, training: loss:0.070621, bce: 0.036578, dice: 0.150055
Epoch 20, validation: loss:0.295852, bce: 0.101750, dice: 0.748757
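Since a model file is written after every epoch, picking the best one comes down to finding the epoch with the lowest validation loss. Here is a minimal sketch that uses only the four validation losses visible in the log above; the data frame valid_log is assembled by hand for illustration, whereas in practice you would collect the per-epoch means inside the loop.

```r
# Validation losses printed above (only epochs 1, 2, 19 and 20 are shown)
valid_log <- data.frame(
  epoch = c(1, 2, 19, 20),
  loss = c(0.333961, 0.341121, 0.302878, 0.295852)
)

# Epoch with the lowest validation loss
best_epoch <- valid_log$epoch[which.min(valid_log$loss)]
best_epoch  # 20

# File name following the paste0("model_", epoch, ".pt") pattern used in the loop
best_file <- paste0("model_", best_epoch, ".pt")
best_file   # "model_20.pt"
```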
Evaluation
In this run, it is the final model that performs best on the validation set. Still, we would like to show how to load a saved model, using torch_load().
Once loaded, put the model into eval mode:
saved_model <- torch_load("model_20.pt")
model <- saved_model
model$eval()
Since we don’t have a separate test set, we already know the average out-of-sample metrics. But in the end, what we care about are the generated masks. Let’s view some, displaying the ground truth and the MRI scans for comparison.
# without random sampling, we'd mainly see lesion-free patches
eval_ds <- brainseg_dataset(valid_dir, augmentation_params = NULL, random_sampling = TRUE)
eval_dl <- dataloader(eval_ds, batch_size = 8)
batch <- eval_dl %>% dataloader_make_iter() %>% dataloader_next()
par(mfcol = c(3, 8), mar = c(0, 1, 0, 1))
for (i in 1:8) {
  # keep the batch dimension so the model sees a batch of size 1
  img <- batch[[1]][i, .., drop = FALSE]
  inferred_mask <- model(img$to(device = device))
  true_mask <- batch[[2]][i, .., drop = FALSE]$to(device = device)

  bce <- nnf_binary_cross_entropy(inferred_mask, true_mask)$to(device = "cpu") %>%
    as.numeric()
  dc <- calc_dice_loss(inferred_mask, true_mask)$to(device = "cpu") %>% as.numeric()
  cat(sprintf("\nSample %d, bce: %3f, dice: %3f\n", i, bce, dc))

  # threshold the predicted probabilities to obtain a binary mask
  inferred_mask <- inferred_mask$to(device = "cpu") %>% as.array() %>% .[1, 1, , ]
  inferred_mask <- ifelse(inferred_mask > 0.5, 1, 0)

  img[1, 1, , ] %>% as.array() %>% as.raster() %>% plot()
  true_mask$to(device = "cpu")[1, 1, , ] %>% as.array() %>% as.raster() %>% plot()
  inferred_mask %>% as.raster() %>% plot()
}
We also print the individual cross-entropy and dice losses; relating these to the generated masks may yield useful information for model tuning.
Sample 1, bce: 0.088406, dice: 0.387786
Sample 2, bce: 0.026839, dice: 0.205724
Sample 3, bce: 0.042575, dice: 0.187884
Sample 4, bce: 0.094989, dice: 0.273895
Sample 5, bce: 0.026839, dice: 0.205724
Sample 6, bce: 0.020917, dice: 0.139484
Sample 7, bce: 0.094989, dice: 0.273895
Sample 8, bce: 2.310956, dice: 0.999824
While they’re far from perfect, most of these masks aren’t that bad. Considering the small data set, this is a good result!
Wrap-up
This has been our most complex torch post so far; we hope you have found the time well spent. For one, among applications of deep learning, medical image segmentation stands out as highly useful to society. Secondly, U-Net-like architectures are employed in many other areas. And finally, we once more saw torch’s flexibility and intuitive behavior in action.
Thanks for reading!
Buda, Mateusz, Ashirbani Saha, and Maciej A. Mazurowski. 2019. “Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm.” Computers in Biology and Medicine 109: 218–25. https://doi.org/10.1016/j.compbiomed.2019.05.002.
Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. 2015. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” CoRR abs/1505.04597. http://arxiv.org/abs/1505.04597.