U-net


U-nets are for problems where the size of your output is the same as the size of your input, and roughly aligned with it. You would use a U-net for any kind of generative modeling (this includes segmentation, which is kind of like generative modeling - you're generating a picture which is a mask of the original objects).

For classification, it makes no sense to use a U-net - you just want the downsampling path.

If we take the CamVid dataset, we're using a U-net. We start with pretraining: a pretrained ResNet34.

In the original U-net paper they start with a 572x572x1 image and output 64 channels. Then they halve the size many times and end up with 1024 channels and a 28x28 image. That's the downsampling path.

However, we need to end up with an image that is as big as the original. How do we do computation that increases the grid size? We do a stride-half conv, aka a deconvolution or transposed convolution.

It looks like this:

Your input is a 2x2 image. You add 2 pixels of padding on the outside, as well as padding pixels that separate the input pixels. So with a 3x3 kernel you end up going from a 2x2 input to a 5x5 output. If you put only 1 pixel of padding you end up with a 4x4 output.

This is how you increase resolution.
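To make the sizes concrete, here is a minimal sketch (PyTorch is an assumption here - the lesson works at the level of pictures): a 3x3 transposed convolution with stride 2 takes a 2x2 input to a 5x5 output.

```python
import torch
import torch.nn as nn

# Stride-half (transposed) conv: out = (in - 1) * stride - 2 * padding + kernel
x = torch.randn(1, 1, 2, 2)  # a 2x2, single-channel input
deconv = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=0)
print(deconv(x).shape)       # torch.Size([1, 1, 5, 5])
```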

But this isn't a great approach - there are a lot of 0s, which is a waste of time and computation. Also, different parts of the conv have access to different amounts of information - sometimes the kernel only sees 1 pixel of the actual input, sometimes 2 or 3...

Nowadays people do this:

On top of this we do a stride-1 convolution. There are no 0s - instead you get a mix of a's and b's, etc.

You can also use bilinear interpolation - you take a weighted average of the neighboring input pixels.
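A sketch of that modern recipe, again assuming PyTorch: upsample with nearest-neighbor (or bilinear) interpolation, then run an ordinary stride-1 conv over the result.

```python
import torch
import torch.nn as nn

# Upsample-then-conv: no zeros are inserted, every kernel position sees real data.
x = torch.randn(1, 64, 28, 28)
up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),            # or mode='bilinear'
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),  # stride-1 conv on top
)
print(up(x).shape)  # torch.Size([1, 64, 56, 56])
```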

But if you're only using these techniques it doesn't work very well - the image lacks fine detail.

Adding a Skip connection - What a U-net really is

What you can do is add a skip connection (or identity connection). But rather than adding skip connections that skip every 2 convolutions, ResNet-style, they add skip connections where the grey lines are in the picture above.

You add a skip connection from the downsampling path to the same-sized activations in the upsampling path. Also, they didn't add - they concatenated. That's why we can see the white and the blue together in the image. They are kind of like dense blocks.
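The difference is easy to see on tensor shapes (a PyTorch sketch; the shapes here are made up for illustration):

```python
import torch

down = torch.randn(1, 64, 28, 28)  # activations from the downsampling path
up   = torch.randn(1, 64, 28, 28)  # same-sized activations in the upsampling path

res_style  = down + up                     # ResNet-style add: still (1, 64, 28, 28)
unet_style = torch.cat([down, up], dim=1)  # U-net concatenates: (1, 128, 28, 28)
```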

For example in this section:

You literally have the original input pixels coming into the computation of the last few layers. This makes it easy to figure out the fine details in the segmentation task, because we have all the pixels.

On the downside, those pixels don't have many layers of computation applied to them - only 4 layers.

Fastai U-net code:

The key thing in the code is the encoder fastai uses, which refers to this:

They use a better encoder, a ResNet34 - the original U-net doesn't use this.

Our layers start with an encoder, then batchnorm, ReLU, then middle_conv - which is conv_layer()... see the code above. conv_layer is basically a Conv2d, batchnorm, ReLU combo.

middle_conv is the 2 extra steps at the bottom of the U-net after the encoder, doing a bit of computation. It's nice to add layers of computation where you can.
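As a rough sketch of what that combo looks like (this is an assumption about the shape of fastai's conv_layer, not its actual source - the real one has many more options):

```python
import torch.nn as nn

# conv_layer = Conv2d -> BatchNorm -> ReLU (minimal sketch)
def conv_layer(in_c, out_c, ks=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=ks, stride=stride, padding=ks // 2),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    )
```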

Then we go through a set of indexes - the layer numbers where each of the stride-2 convs occurs - and we store those indexes:
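One self-contained way to find those indexes (a hypothetical sketch using plain torchvision rather than fastai's internal helpers): push a dummy batch through the encoder and note every layer after which the grid shrinks.

```python
import torch
from torchvision.models import resnet34

encoder = list(resnet34().children())[:-2]   # drop the avgpool and fc head
x = torch.randn(1, 3, 224, 224)
sizes = [x.shape[-1]]
with torch.no_grad():
    for layer in encoder:
        x = layer(x)
        sizes.append(x.shape[-1])
# indexes of the layers that shrank the grid (stride-2 convs and the maxpool)
idxs = [i - 1 for i in range(1, len(sizes)) if sizes[i] != sizes[i - 1]]
print(idxs)
```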

We can loop through those, and for each one of those points, create a U-net block. We pass this class the number of upsampling channels (up_in_c) and the number of cross-connection channels (x_in_c - the skip connections).

They also use other tweaks - e.g. in the upsampling they use pixel shuffle (with ICNR initialization), and they add another cross connection (called last_cross in the code): it takes the input pixels and feeds them in as a cross connection, not just the output of the first layer.

UNetBlock

Most of the work happens in the UNetBlock. It has to store the activations at each of the downsampling points, and the way to do that is with a hook. So they put hooks into the ResNet34 to store the activations each time there is a stride-2 conv. Then they concatenate the upsampled convolution with the result of the hook (which they first pass through batchnorm - self.bn(s)). Finally they run convolutions on that.
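Putting it together, a stripped-down sketch of the idea (the names and layer choices here are illustrative assumptions, not fastai's actual UnetBlock, which also does pixel shuffle, ICNR init, etc.):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UnetBlockSketch(nn.Module):
    def __init__(self, up_in_c, x_in_c):
        super().__init__()
        # upsample the path coming from below (doubles the grid size)
        self.upconv = nn.ConvTranspose2d(up_in_c, up_in_c // 2, 2, stride=2)
        self.bn = nn.BatchNorm2d(x_in_c)  # applied to the hooked activations
        ni = up_in_c // 2 + x_in_c        # channels after concatenation
        self.conv1 = nn.Conv2d(ni, ni // 2, 3, padding=1)
        self.conv2 = nn.Conv2d(ni // 2, ni // 2, 3, padding=1)

    def forward(self, up_in, hooked):
        # hooked = activations stored by a forward hook on the encoder,
        # assumed to have the same grid size as the upsampled output
        s = self.bn(hooked)
        cat = torch.cat([self.upconv(up_in), s], dim=1)  # concatenate, don't add
        return self.conv2(F.relu(self.conv1(F.relu(cat))))
```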

Nearest neighbor interpolation