U-net
U-nets are for problems where the size of your output is the same as the size of your input, and roughly aligned with it. You would use a U-net for any kind of generative modeling (this includes segmentation, which is kind of like generative modeling - it's generating a picture which is a mask of the original objects).
For classification it makes no sense to use a U-net - you just want the downsampling path.
If we take the CamVid dataset, we're using a U-net. We start with pretraining, using a ResNet34.
In the original U-net paper they start with an image of 572x572x1 and output 64 channels. Then they halve the grid size several times and end up with 1024 channels and a 28x28 image. That's the downsampling path.
However, we need to end up with an image that is as big as the original. How do we do computation that increases the grid size? We do a stride-half conv, aka a deconvolution or transposed convolution.
It looks like this
Your input is a 2x2 image. You add 2 pixels of padding on the outside, as well as zero pixels that separate the input pixels. So with a 3x3 kernel you end up going from a 2x2 input to a 5x5 output. If you put only 1 pixel of padding you end up with a 4x4 output.
This is how you increase resolution.
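Here's a minimal PyTorch sketch of this (shapes are just for illustration; the `output_padding` knob in the second case is a PyTorch-specific detail):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 2, 2)             # batch of one 2x2 single-channel image

# 3x3 kernel, stride 2: zeros are (conceptually) inserted between and around
# the input pixels, so the 2x2 input becomes a 5x5 output
deconv = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=0)
print(deconv(x).shape)                  # torch.Size([1, 1, 5, 5])

# with less effective padding you get the smaller 4x4 variant mentioned above
deconv2 = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2,
                             padding=1, output_padding=1)
print(deconv2(x).shape)                 # torch.Size([1, 1, 4, 4])
```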
But this isn't a great approach - there are a lot of 0s, which is a waste of time and computation. Also, different parts of the conv have access to different amounts of information - sometimes the kernel covers only 1 input pixel, sometimes 2 or 3...
Nowadays people do nearest-neighbour upsampling instead, where each pixel is simply copied into a 2x2 block:
On top of this we do a stride-1 convolution. No 0s, and you get a mix of a's and b's, etc.
You can also use bilinear interpolation for the upsampling - you take a weighted average of the neighbouring pixels rather than copying them.
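A quick sketch of the upsample-then-conv recipe (channel counts are made up):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 28, 28)

# nearest-neighbour upsample: each pixel is copied into a 2x2 block
# (swap in mode='bilinear' for a weighted average instead of copies)
up = nn.Upsample(scale_factor=2, mode='nearest')

# stride-1 conv on top mixes the duplicated values - no zeros involved
conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

print(conv(up(x)).shape)                # torch.Size([1, 64, 56, 56])
```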
But if you only use these techniques, it doesn't work very well - the image lacks fine detail.
What you can do is add a skip connection (or identity connection). Rather than adding a skip connection that skips every 2 convolutions, as in a ResNet, the U-net adds skip connections where the grey lines are in the picture above.
You add a skip connection from the downsampling path to the same-size activations in the upsampling path. Also, they didn't add - they concatenated. That's why you can see the white and the blue together in the image. They are kind of like dense blocks.
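A quick illustration of the difference between the two kinds of skip connection (the shapes here are made up):

```python
import torch

down = torch.randn(1, 256, 28, 28)     # activations saved on the downsampling path
up   = torch.randn(1, 256, 28, 28)     # upsampled activations of the same size

added  = up + down                     # ResNet-style skip: channels stay at 256
catted = torch.cat([up, down], dim=1)  # U-net-style skip: channels stack to 512
print(added.shape, catted.shape)
```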
For example in this section:
You literally have the original input pixels coming into the computation of the last few layers. This is going to make it handy to figure out the fine details in the segmentation task, because we have all the original pixels.
On the downside, there aren't many layers of computation going on for that path - only 4 layers.
The key thing in the code is the encoder fastai uses, which refers to this:
They use a better encoder, a ResNet34 - the original U-net doesn't use one.
Our layers start with the encoder, then batchnorm, ReLU, then middle_conv, which is conv_layer()... see the code above. conv_layer is basically a conv2d, batchnorm, ReLU combo.
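As a rough sketch (this isn't fastai's exact source - the argument names and defaults are simplified, and the channel counts in middle_conv are just plausible values for a ResNet34 encoder), conv_layer looks something like:

```python
import torch.nn as nn

def conv_layer(ni, nf, ks=3, stride=1):
    """Rough stand-in for fastai's conv_layer: conv2d + batchnorm + ReLU."""
    return nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2),
        nn.BatchNorm2d(nf),
        nn.ReLU(inplace=True),
    )

# middle_conv: two of these stacked at the bottom of the U
middle_conv = nn.Sequential(conv_layer(512, 1024), conv_layer(1024, 512))
```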
middle_conv is the 2 extra steps at the bottom of the U-net, after the encoder - doing a bit of computation. It's nice to add layers of computation where you can.
Then we go through a set of indexes - the layer numbers at which each of the stride-2 convs occurs - and we store those indexes:
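fastai works these out by tracking where the feature-map sizes change as a dummy batch goes through the encoder. Here's a hypothetical helper (stride2_idxs is my name, not fastai's) that captures the idea:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet34

def stride2_idxs(encoder, size=(224, 224)):
    """Hypothetical helper: indexes of encoder layers where the grid size drops."""
    x = torch.randn(1, 3, *size)
    idxs, prev = [], size[-1]
    with torch.no_grad():
        for i, layer in enumerate(encoder.children()):
            x = layer(x)
            if x.shape[-1] < prev:      # spatial size dropped -> a stride-2 happened
                idxs.append(i)
                prev = x.shape[-1]
    return idxs

body = nn.Sequential(*list(resnet34().children())[:-2])  # cut off pooling + fc head
print(stride2_idxs(body))               # e.g. [0, 3, 5, 6, 7]
```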
We can loop through those and, for each one of those points, create a U-net block. To this class we pass in how many upsampling channels there are (up_in_c) and how many cross-connection channels there are (x_in_c) - the skip connections.
They also use other tweaks - e.g. in the upsampling they use pixel shuffle with ICNR initialisation, and they also add another cross-connection (called last_cross in the code). It takes the input pixels and feeds them in as a cross-connection, not just the output of the first layer.
Most of the work happens in UnetBlock. It has to store the activations at each of the downsampling points, and the way to do that is with a hook. So they put hooks into the ResNet34 to store the activations each time there is a stride-2 conv. What they do with them is concatenate the upsampled convolution with the result of the hook (which they chuck through batchnorm, self.bn(s)). Then they do a couple of convolutions to it.
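Here's a stripped-down sketch of the idea. This is not fastai's UnetBlock (which uses pixel shuffle for the upsampling and handles size mismatches); it just shows the hook-store-concat pattern, with made-up channel counts:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Hook:
    """Stores a module's output each time it runs (a forward hook)."""
    def __init__(self, module):
        self.handle = module.register_forward_hook(self.fn)
    def fn(self, module, input, output):
        self.stored = output

class UnetBlockSketch(nn.Module):
    """Sketch of a U-net block: upsample, concat with hooked activations, conv."""
    def __init__(self, up_in_c, x_in_c, hook):
        super().__init__()
        self.hook = hook
        # transposed conv stands in for fastai's pixel shuffle upsampling
        self.upconv = nn.ConvTranspose2d(up_in_c, up_in_c // 2, 2, stride=2)
        self.bn = nn.BatchNorm2d(x_in_c)
        ni = up_in_c // 2 + x_in_c
        self.conv1 = nn.Conv2d(ni, ni, 3, padding=1)
        self.conv2 = nn.Conv2d(ni, ni, 3, padding=1)
    def forward(self, up_in):
        s = self.hook.stored                          # saved on the way down
        up_out = self.upconv(up_in)                   # double the grid size
        cat = torch.cat([up_out, self.bn(s)], dim=1)  # concat with batchnormed skip
        return self.conv2(F.relu(self.conv1(F.relu(cat))))

# wire up: hook an encoder layer, then use its stored activations in the block
layer = nn.Conv2d(3, 256, 3, stride=2, padding=1)
hook = Hook(layer)
_ = layer(torch.randn(1, 3, 56, 56))                  # fires the hook, stores 28x28
block = UnetBlockSketch(up_in_c=512, x_in_c=256, hook=hook)
print(block(torch.randn(1, 512, 14, 14)).shape)       # torch.Size([1, 512, 28, 28])
```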