Semantic Segmentation with U-Nets
Semantic segmentation is a complex task in computer vision, aiming to assign a class label to each pixel in an image. Unlike object detection, which identifies bounding boxes, semantic segmentation provides a pixel-wise classification, crucial for applications like autonomous driving and medical imaging.
For instance, in autonomous driving, it's used to identify drivable surfaces. In medical imaging, it helps in segmenting different anatomical structures or pathologies, such as in chest radiographs or MRI brain scans.
In medical imaging, given a chest X-ray, you may want to diagnose whether someone has a certain condition. What may be even more helpful to doctors is to segment out exactly which pixels in the image correspond to certain parts of the patient's anatomy.
In the image on the left, the lungs, the heart, and the clavicles (the collarbones) are segmented out using different colours. This segmentation can make it easier to spot irregularities and diagnose serious diseases, and can also help surgeons plan surgeries. In another example, a brain MRI scan is used for brain tumour detection. Manually segmenting out the tumour is time-consuming and laborious; if a learning algorithm can segment the tumour automatically, this saves radiologists a lot of time and is also a useful input for surgical planning.
Novikov et al., 2017, Fully Convolutional Architectures for Multi-Class Segmentation in Chest Radiographs
Dong et al., 2017, Automatic Brain Tumor Detection and Segmentation Using U-Net Based Fully Convolutional Networks
U-Net Architecture
U-Net, a powerful architecture for semantic segmentation, excels in tasks requiring detailed, pixel-level labeling.
Structure of U-Net
Instead of just giving a single class label or maybe a class label and coordinates specifying a bounding box, the neural network has to generate a whole matrix of labels.
One key difference in semantic segmentation is that, whereas in a typical convolutional network the height and width of the activations generally get smaller as we go from left to right, here they need to get bigger again so that the network can gradually blow the activations back up to a full-size image, which is the size you want for the output.
Specifically, this is what a U-Net architecture looks like. As we go deeper into the U-Net, the height and width go back up while the number of channels decreases, until eventually you get your segmentation map of the cat.
To do this, we need to use the transpose convolution.
Transpose Convolutions in U-Nets
Transpose convolutions, a key component in the expansive path, are used to upsample the feature maps:
Normal Convolution:
Transpose Convolution:
In detail:
- Use convolution, in this case
- Padding
- Stride
In a regular convolution, you place the filter on top of the input, multiply element-wise, and sum. In a transpose convolution, you instead place the filter on the output: each input value multiplies every number in the filter, and the scaled filter is written into the corresponding region of the output, which up-samples the input.
Where the red and the green boxes overlap, you add two values together:
There are multiple possible ways to take a small input and turn it into a bigger output, but the transpose convolution happens to be an effective one. When the parameters of the filter are learned, it gives good results in the context of the U-Net, which is the learning algorithm we will use now. This is the key building block of the U-Net architecture.
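The "place the scaled filter on the output and add where the copies overlap" procedure can be sketched in plain Python. This is a minimal illustration and not a framework implementation: the function name `transpose_conv2d`, the square single-channel input, and the example numbers are all assumptions made here. Padding is treated as rows and columns cropped symmetrically from the output, which gives an output size of (n - 1) * stride + f - 2 * padding; real libraries may use additional conventions such as extra output padding.

```python
def transpose_conv2d(x, kernel, stride=1, padding=0):
    """Transpose convolution of a square 2D input with a square 2D kernel.

    Each input value scales the whole kernel, and the scaled kernel is added
    into the output at an offset of `stride` per input position.
    """
    n = len(x)
    f = len(kernel)
    full = (n - 1) * stride + f  # output size before padding is cropped
    out = [[0.0] * full for _ in range(full)]
    for i in range(n):
        for j in range(n):
            for a in range(f):
                for b in range(f):
                    # Overlapping placements accumulate by addition.
                    out[i * stride + a][j * stride + b] += x[i][j] * kernel[a][b]
    if padding:
        # Padding removes a border of the output (assumed symmetric here).
        out = [row[padding:full - padding] for row in out[padding:full - padding]]
    return out


# Hypothetical 2x2 input and 3x3 filter, chosen only for illustration.
x = [[2, 1],
     [3, 2]]
kernel = [[1, 2, 1],
          [2, 0, 1],
          [0, 2, 1]]

for row in transpose_conv2d(x, kernel, stride=2):
    print(row)
```

With a stride of 2 and a 3x3 filter, neighbouring placements share one row or column, so the middle row and column of the 5x5 result hold sums of two or four contributions.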
U-Net Architecture Intuition
The U-Net architecture comprises two main parts:
- Contracting Path (Downsampling): Consists of convolutional layers followed by max-pooling layers, progressively reducing spatial dimensions while increasing feature depth.
- Expansive Path (Upsampling): Uses transpose convolutions to increase spatial dimensions, coupled with skip connections from the contracting path to retain high-resolution features.
We add skip connections from the earlier layers to the later layers so that an earlier block of activations is copied directly to a later block. For the next-to-final layer to decide which region is a cat, two types of information are useful:
- High-level spatial and contextual information, which it gets from the previous layer. By this point the neural network has hopefully figured out, for example, that there is some cat-like stuff going on in the lower right-hand corner or the right part of the image. What is missing is detailed, fine-grained spatial information, because this set of activations has lower spatial resolution (its height and width are smaller).
- High-resolution, low-level feature information, which it gets through the skip connection. An earlier layer could capture, for every pixel position, how much furry stuff there is at that pixel, and the skip connection passes that directly to the later layer.
This layer now has both the lower-resolution but high-level spatial and contextual information and the low-level but more detailed, texture-like information, and can use both to decide whether a certain pixel is part of a cat or not.
Ronneberger et al., 2015, U-Net: Convolutional Networks for Biomedical Image Segmentation
The input to the U-Net is an image (say h × w × 3, for the three RGB channels). Here we depict this image as a thin layer, as in the diagram. The first part of the U-Net uses normal feed-forward convolutional layers. Black arrows denote a convolutional layer followed by a ReLU activation function. The next layer may have increased the number of channels a little, but the dimension is still height by width. We repeat this.
We then use max pooling to reduce the height and width. We might end up with a set of activations whose height and width are lower but which is thicker, because the number of channels is increasing.
Then we have two more layers of normal feed-forward convolutions with a ReLU activation function, and then apply max pooling again. We repeat this.
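To make the effect of max pooling on height and width concrete, here is a small sketch for a single channel. The 2x2 window with stride 2 is an assumption (the text above does not state the window size), and `max_pool_2x2` and the example values are illustrative names and numbers only.

```python
def max_pool_2x2(channel):
    """2x2 max pooling with stride 2 on one channel.

    Assumes the height and width are even; each output value is the
    maximum of one non-overlapping 2x2 block of the input.
    """
    h, w = len(channel), len(channel[0])
    return [
        [
            max(channel[i][j], channel[i][j + 1],
                channel[i + 1][j], channel[i + 1][j + 1])
            for j in range(0, w, 2)
        ]
        for i in range(0, h, 2)
    ]


# Hypothetical 4x4 activation map.
activations = [[1, 3, 2, 1],
               [4, 2, 0, 1],
               [5, 1, 9, 2],
               [0, 2, 3, 4]]

pooled = max_pool_2x2(activations)
print(pooled)  # a 2x2 map: height and width are halved
```

Pooling acts on each channel independently, so it halves the height and width but leaves the number of channels unchanged; any growth in channels comes from the convolutional layers around it.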
Notice now that the height of this layer at the bottom of the U is very small. So we're going to start to apply transpose convolution layers (denoted by the green arrow) in order to build the dimension of this neural network back up.
So, with the first transpose convolutional layer, you're going to get a set of activations that looks like that. In this example, we did not increase the height and width, but we did decrease the number of channels.
We add a skip connection as well (denoted by a grey arrow). The skip connection takes this set of activations and just copies it over to the right. And so, the set of activations you end up with is a combination: the light blue part comes from the transpose convolution, and the dark blue part is just copied over from the left.
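The "combination" of copied and up-sampled activations is commonly implemented as a concatenation along the channel axis. The sketch below assumes that interpretation, a height × width × channels layout stored as nested lists, and matching height and width on both inputs; `skip_concat` and the example values are hypothetical.

```python
def skip_concat(skipped, upsampled):
    """Concatenate two [H][W][C] feature maps along the channel axis.

    `skipped` is the block copied from the contracting path and
    `upsampled` is the output of the transpose convolution.
    """
    if len(skipped) != len(upsampled) or len(skipped[0]) != len(upsampled[0]):
        raise ValueError("height and width must match before concatenation")
    return [
        [pix_s + pix_u for pix_s, pix_u in zip(row_s, row_u)]
        for row_s, row_u in zip(skipped, upsampled)
    ]


# Hypothetical 2x2 maps: 2 channels copied over, 1 channel up-sampled.
skipped = [[[1, 2], [3, 4]],
           [[5, 6], [7, 8]]]
upsampled = [[[9], [10]],
             [[11], [12]]]

combined = skip_concat(skipped, upsampled)
print(combined)  # each pixel now carries 2 + 1 = 3 channels
```

The height and width are unchanged by the concatenation; only the channel count grows, which is why the combined block in the diagram looks thicker.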
To keep on building up the U-Net, we then apply a couple more layers of regular convolutions, each followed by a ReLU activation function (denoted by the black arrows), and then we apply another transpose convolutional layer.
With the green arrow we start to increase the dimension, that is, the height and width of this image. And so now the height is getting bigger. But here, too, we apply a skip connection. So there's a grey arrow again, where we take this set of activations and just copy it over to the right.
More convolutional layers and another transpose convolution, then a skip connection: once again, we take this set of activations and copy it over to the right. Then more convolutional layers, followed by another transpose convolution and another skip connection. Now we're back to a set of activations with the original input image's height and width.
We then have a couple more layers of normal feed-forward convolutions and, finally, to map this to our segmentation map, we use a 1x1 convolution (denoted by the magenta arrow) to give us our output.
The dimensions of this final output layer are the same height and width as our original input, by the number of classes, i.e. h × w × n_classes. So if you have three classes to try to recognize, the last dimension will be three.
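The walk down and back up the U can be traced as shape arithmetic. The sketch below is illustrative only and rests on several assumptions not stated in the text: convolutions preserve height and width, each max pooling halves them, channels double at each level of the contracting path, each transpose convolution doubles height and width and halves the channels, and skip connections concatenate along channels. The function `unet_shapes` and its default values are hypothetical.

```python
def unet_shapes(height, width, in_channels=3, base_channels=64,
                depth=4, num_classes=3):
    """Trace (height, width, channels) through a U-Net-style network."""
    if height % (2 ** depth) or width % (2 ** depth):
        raise ValueError("height and width must be divisible by 2**depth")
    h, w = height, width
    trace = [("input", (h, w, in_channels))]
    skip_channels = []
    c = base_channels
    # Contracting path: convolutions, then pooling halves height and width.
    for level in range(depth):
        trace.append((f"enc{level}: conv + ReLU", (h, w, c)))
        skip_channels.append(c)
        h, w = h // 2, w // 2
        trace.append((f"enc{level}: max pool", (h, w, c)))
        c *= 2
    trace.append(("bottom: conv + ReLU", (h, w, c)))
    # Expansive path: transpose convolution, skip concatenation, convolutions.
    for level in reversed(range(depth)):
        h, w, c = h * 2, w * 2, c // 2
        trace.append((f"dec{level}: transpose conv", (h, w, c)))
        trace.append((f"dec{level}: concat skip", (h, w, c + skip_channels[level])))
        trace.append((f"dec{level}: conv + ReLU", (h, w, c)))
    # A 1x1 convolution maps the channels to one score per class.
    trace.append(("output: 1x1 conv", (h, w, num_classes)))
    return trace


for name, shape in unet_shapes(64, 64):
    print(f"{name:28s} {shape}")
```

Under these assumptions the last entry has the input's height and width and `num_classes` channels, matching the output dimensions described above.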
For every pixel you have a vector of n_classes numbers that tells you how likely that pixel is to come from each of the different classes.
If you then take an arg max over these classes, you classify each pixel into one of the classes, and you can visualize the result like the segmentation map shown on the right.
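The last two steps, a 1x1 convolution followed by a per-pixel arg max, can be sketched as follows. A 1x1 convolution is treated here as the same linear map over channels applied independently at every pixel. The helper names `conv1x1` and `argmax_classes`, the height × width × channels layout, and all weights and feature values are made up for illustration.

```python
def conv1x1(features, weights, biases):
    """Apply a 1x1 convolution to an [H][W][C] feature map.

    `weights` has shape [K][C] and `biases` has shape [K]; the result
    has shape [H][W][K], one score per class at every pixel.
    """
    return [
        [
            [sum(w_c * x_c for w_c, x_c in zip(w_k, pixel)) + b_k
             for w_k, b_k in zip(weights, biases)]
            for pixel in row
        ]
        for row in features
    ]


def argmax_classes(scores):
    """Return the index of the highest-scoring class at each pixel."""
    return [
        [max(range(len(pixel)), key=pixel.__getitem__) for pixel in row]
        for row in scores
    ]


# Hypothetical 2x2 feature map with 2 channels, mapped to 3 classes.
features = [[[1.0, 0.0], [0.0, 1.0]],
            [[0.0, 0.0], [3.0, 1.0]]]
weights = [[1.0, 0.0],
           [0.0, 1.0],
           [-1.0, -1.0]]
biases = [0.0, 0.0, 1.0]

scores = conv1x1(features, weights, biases)
segmentation_map = argmax_classes(scores)
print(segmentation_map)  # one class index per pixel
```

The resulting grid of class indices has the same height and width as the feature map and can be coloured by class to produce a segmentation map.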