Background removal with deep learning

Original author: Gidi Shperber

Translation of the article "Background removal with deep learning".

Throughout our years of working in machine learning, we have wanted to build real products based on it.

A few months ago, after completing the excellent Fast.AI course, the stars aligned and we got our chance. Recent advances in deep learning have made possible much of what previously seemed impossible, and new tools have made implementation more accessible than ever.

We set the following goals:

  1. Improve our deep learning skills.
  2. Improve our AI-based product implementation skills.
  3. Create a useful product with market prospects.
  4. Have fun (and help our users have fun).
  5. Share our experience.

With these goals in mind, we looked for ideas that:

  1. Nobody had implemented yet (or had not implemented properly).
  2. Would not be too complicated to plan and implement: we allotted 2-3 months to the project, at one working day per week.
  3. Would have a simple and attractive user interface: we wanted to make a product that people would actually use, not just a demo.
  4. Would have training data available: as any machine learning practitioner knows, data is sometimes more expensive than the algorithm.
  5. Would use state-of-the-art deep learning methods (not yet brought to market by Google, Amazon and their cloud-platform peers), but not too cutting-edge (so that we could find some examples online).
  6. Would have the potential to reach results good enough to bring the product to market.

Our first idea was to take on a medical project, since this field is close to our hearts, and we felt (and still feel) that it offers a huge number of problems well suited to deep learning. However, we realized that we would run into difficulties collecting data, and possibly legal and regulatory issues, which conflicted with our wish to keep things simple. So we fell back on plan B: a product that removes the background from images.

Background removal is easy to do manually or semi-manually (Photoshop and even PowerPoint have such tools) if you use some kind of "marker" combined with edge detection; see the example. However, fully automated background removal is quite a difficult task, and as far as we know there is still no product that achieves acceptable results (although some are trying).

Which background would we remove? This question turned out to be important: the more specific a model is in terms of objects, angles and so on, the better it separates background from foreground. When we started, we thought big: a general background removal tool that automatically identifies foreground and background in any type of image. But after training our first model, we realized it was better to focus our efforts on a specific set of images. So we decided to focus on selfies and portraits of people.

Background removal on a photo of (almost) a person.

A selfie is an image that:

  • has a salient, focused foreground (one or more "people"), which gives us good separation between the object (face + upper body) and the background,
  • has a fairly constant angle and always the same kind of object (a person).

With these assumptions in mind, we set off on research and implementation, spending many hours on training to create an easy-to-use, one-click background removal service.

Most of our work went into training the model, but we could not afford to underestimate the importance of proper deployment. Good segmentation models are still not as compact as image classification models (for example, SqueezeNet), so we actively explored deployment options on both the server side and the browser side.

If you want to read more about how we deployed our product, see our posts on the server-side and client-side implementations.

If you want to learn about the model and the process of its training, continue reading here.

Semantic segmentation

When surveying deep learning and computer vision tasks that resemble ours, it is easy to see that our best option is semantic segmentation.

There are other strategies, such as separation based on depth detection, but they did not seem mature enough for our purposes.

Semantic segmentation is a well-known computer vision task, one of the three most important ones, along with classification and object detection. Segmentation is in fact a classification task, in the sense that every pixel is assigned to a class. Unlike image classification or detection models, a segmentation model does show some "understanding" of the image: it not only says "there is a cat in this image", but also points out where the cat is, at the pixel level.
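To make the "classification per pixel" idea concrete, here is a minimal NumPy sketch. It assumes the network has already produced a score for every class at every pixel (an array of shape (H, W, C)); the function name `logits_to_mask` is ours, not from any library.

```python
import numpy as np


def logits_to_mask(logits: np.ndarray) -> np.ndarray:
    """Turn per-pixel class scores of shape (H, W, C) into a label map (H, W).

    Each pixel is classified independently: its label is the class with the
    highest score, just as a classifier picks one label for a whole image.
    """
    if logits.ndim != 3:
        raise ValueError("expected scores of shape (H, W, C)")
    return np.argmax(logits, axis=-1)


# A toy 2x3 "image" with 2 classes: 0 = background, 1 = person.
scores = np.array([
    [[2.0, 0.1], [0.3, 1.5], [0.2, 0.9]],
    [[1.0, 0.0], [0.0, 3.0], [5.0, 0.0]],
])
mask = logits_to_mask(scores)
print(mask)
```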

So how does segmentation work? To understand it better, we need to look at some of the early work in this field.

The very first idea was to adapt some of the early classification networks, such as VGG and AlexNet. VGG (Visual Geometry Group) was a leading image classification model in 2014, and it is still very useful today thanks to its simple, clear architecture. Looking at the early layers of VGG, you can see high activations around the object being classified. Deeper layers have even stronger activations, but they are coarse by nature because of the repeated pooling. With all this in mind, it was hypothesized that a network trained for classification could, with some modifications, also be used to locate and segment the object.

Early semantic segmentation results appeared alongside the classification algorithms. In this post you can see some rough segmentation results obtained with VGG:

Results of deeper layers:
Segmentation of a bus image; light purple (29) is the school bus class.

After bilinear upsampling:
These results come simply from converting (or keeping) the fully connected layer in its original shape, preserving its spatial features, which yields a fully convolutional network. In the example above, we feed a 768 × 1024 image into VGG and get a 24 × 32 × 1000 layer: 24 × 32 is the image after pooling (downsampled by a factor of 32), and 1000 is the number of ImageNet classes, from which the segmentation above is derived.

To smooth the prediction, the researchers simply used a naive bilinear upsampling layer.
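The arithmetic and the upsampling step can be sketched in NumPy. This is an illustration under stated assumptions, not the code behind the bus example: the score map is random, it has 3 classes instead of 1000 to keep memory small, and `coarse_shape` / `bilinear_upsample` are hypothetical helper names.

```python
import numpy as np

VGG_STRIDE = 32  # five 2x2 max-poolings: 2 ** 5


def coarse_shape(height, width, stride=VGG_STRIDE):
    """Spatial size of the score map a fully convolutional VGG produces."""
    return height // stride, width // stride


def bilinear_upsample(x, factor):
    """Bilinearly upsample an (h, w, c) array by an integer factor (half-pixel centres)."""
    h, w = x.shape[:2]
    ys = np.clip((np.arange(h * factor) + 0.5) / factor - 0.5, 0, h - 1)
    xs = np.clip((np.arange(w * factor) + 0.5) / factor - 0.5, 0, w - 1)
    y0, x0 = np.floor(ys).astype(int), np.floor(xs).astype(int)
    y1, x1 = np.minimum(y0 + 1, h - 1), np.minimum(x0 + 1, w - 1)
    wy = (ys - y0)[:, None, None]  # vertical interpolation weights
    wx = (xs - x0)[None, :, None]  # horizontal interpolation weights
    top = x[y0][:, x0] * (1 - wx) + x[y0][:, x1] * wx
    bottom = x[y1][:, x0] * (1 - wx) + x[y1][:, x1] * wx
    return top * (1 - wy) + bottom * wy


print(coarse_shape(768, 1024))  # (24, 32), as in the bus example
rng = np.random.default_rng(0)
scores = rng.random((24, 32, 3))  # 3 classes instead of 1000, to keep memory small
labels = np.argmax(bilinear_upsample(scores, VGG_STRIDE), axis=-1)
print(labels.shape)  # (768, 1024)
```

Upsampling the scores before taking the argmax (rather than upsampling the labels) is what lets the interpolation blend neighbouring predictions.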

In the FCN paper, the authors improved on this idea. They combined several layers along the way to get richer interpretations, named FCN-32, FCN-16 and FCN-8 according to their upsampling rate:
Adding skip connections between layers allowed the prediction to encode finer details of the original image. Further training improved the results even more.
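The skip-connection idea behind FCN-16 can be sketched as follows: upsample the coarse stride-32 class scores by 2, add the stride-16 scores from an earlier layer, then upsample to full resolution. This is a simplified sketch: FCN learns its upsampling with transposed convolutions initialised to bilinear interpolation, while this version uses nearest-neighbour repetition, and the score maps are random stand-ins. Function names are hypothetical.

```python
import numpy as np


def upsample_nearest(x, factor):
    """Nearest-neighbour upsampling of an (h, w, c) map by an integer factor."""
    return x.repeat(factor, axis=0).repeat(factor, axis=1)


def fcn16_fuse(score_pool5, score_pool4):
    """FCN-16-style fusion of per-class score maps.

    score_pool5: coarse scores at stride 32, shape (h, w, C)
    score_pool4: finer scores at stride 16, shape (2h, 2w, C)
    Returns full-resolution scores, shape (32h, 32w, C).
    """
    up2 = upsample_nearest(score_pool5, 2)  # stride 32 -> stride 16
    if up2.shape != score_pool4.shape:
        raise ValueError("score maps do not line up after 2x upsampling")
    fused = up2 + score_pool4  # the skip connection: element-wise sum
    return upsample_nearest(fused, 16)  # stride 16 -> input resolution


rng = np.random.default_rng(0)
s5 = rng.standard_normal((2, 3, 2))  # 64x96 input, 2 classes
s4 = rng.standard_normal((4, 6, 2))
full = fcn16_fuse(s5, s4)
print(full.shape)
```

FCN-8 repeats the same step once more with stride-8 scores before the final upsampling.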

The method performed better than one might expect, and proved that semantic segmentation with deep learning really has potential.

FCN Results.

FCN laid down the concept of segmentation, and researchers went on to try different architectures for the task. The main ideas stayed the same: well-known classification architectures, upsampling and skip connections are still present in more recent models.

You can read about progress in this field in several good posts: here, here and here. You may also notice that most architectures follow an encoder-decoder scheme.

Returning to our project

After some research, we settled on three candidate models: FCN, Unet and Tiramisu, all quite deep encoder-decoder architectures. We also considered Mask R-CNN, but implementing it was outside the scope of our project.

FCN did not seem relevant, since its results were not as good as we wanted (even as a starting point), but the other two models showed decent results: the main advantages of Unet and Tiramisu on the CamVid dataset were their compactness and speed. Unet was quite simple to implement (we used Keras), and Tiramisu was also quite feasible. To get started, we used the good Tiramisu implementation from the last lesson of Jeremy Howard's deep learning course.

We started training these two models on some datasets. It must be said that after we first tried Tiramisu, its results looked much more promising to us, since the model could capture sharp edges in the image. Unet, on the other hand, was not good enough, and its results looked somewhat blurry.

Unet blurriness.


Having chosen a model, we started looking for suitable datasets. Segmentation data is not as common as classification or even detection data. In addition, labeling images manually was not an option for us. The most popular segmentation datasets were COCO, which includes about 80 thousand images in 90 categories, PASCAL VOC with 11 thousand images in 20 classes, and the more recent ADE20K.

We decided to work with COCO, since it includes much more images of the “person” class that interested us.

Given our task, we had to decide whether to use only the images relevant to us or a more "general" dataset. On the one hand, a more general dataset with more images and classes would let the model handle more scenarios and tasks. On the other hand, one night of training let us process ~150 thousand images. If we fed the entire COCO dataset to the model, it would see each image about twice on average, so it was better to trim the dataset somewhat. In addition, this would make our model better tailored to our task.

Another point worth mentioning: the Tiramisu model was originally trained on the CamVid dataset, which has some drawbacks, the main one being the strong uniformity of its images: road photos taken from cars. As you can imagine, training on such a dataset (even though it contains people) did not help us, so after some trials we moved on.

Images from CamVid dataset.

The COCO dataset comes with a fairly simple API that allows us to know exactly which objects are in which image (according to 90 predefined classes).

After some experiments, we decided to thin out the dataset. First, we kept only images containing a person, which left 40 thousand pictures. Then we discarded images with many people, keeping only photos with 1-2 people, since our product is designed for such situations. Finally, we kept only images in which the person occupies 20%-70% of the area, removing pictures where the person is too small or shows some strange monstrosity (unfortunately, we could not remove all of them). Our final dataset consisted of 11 thousand images, which we felt was enough at this stage.

Left: a suitable image. Center: too many people. Right: the person is too small.
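The post does not show its filtering code. The sketch below applies the same three rules directly to a COCO-format annotation dict (an `images` list with `id`, `width`, `height`, and an `annotations` list with `image_id`, `category_id`, `area`, `iscrowd`), so it runs without the COCO API. Assumptions worth flagging: coverage is approximated by summing instance areas (overlaps count twice), "crowd" annotations are treated as too many people, and `select_portrait_images` is a hypothetical name.

```python
from collections import defaultdict

PERSON_CATEGORY_ID = 1  # id of "person" in the standard COCO category list


def select_portrait_images(coco, max_people=2, min_frac=0.2, max_frac=0.7):
    """Return ids of images with 1..max_people people covering min_frac..max_frac of the frame."""
    count = defaultdict(int)
    area = defaultdict(float)
    crowded = set()
    for ann in coco["annotations"]:
        if ann["category_id"] != PERSON_CATEGORY_ID:
            continue  # animals, bats, rackets... are ignored here
        if ann.get("iscrowd", 0):
            crowded.add(ann["image_id"])  # a "crowd" region means many people
        count[ann["image_id"]] += 1
        area[ann["image_id"]] += ann["area"]  # approximate: overlaps count twice

    selected = []
    for img in coco["images"]:
        n = count.get(img["id"], 0)
        if img["id"] in crowded or not 1 <= n <= max_people:
            continue
        frac = area[img["id"]] / (img["width"] * img["height"])
        if min_frac <= frac <= max_frac:
            selected.append(img["id"])
    return selected


# Tiny hand-made example: four 100x100 images.
toy = {
    "images": [{"id": i, "width": 100, "height": 100} for i in (1, 2, 3, 4)],
    "annotations": [
        {"image_id": 1, "category_id": 1, "area": 5000.0},  # one person, 50% -> keep
        {"image_id": 1, "category_id": 18, "area": 900.0},  # a dog, ignored
        *[{"image_id": 2, "category_id": 1, "area": 1000.0} for _ in range(3)],  # 3 people
        {"image_id": 3, "category_id": 1, "area": 500.0},  # person too small (5%)
        {"image_id": 4, "category_id": 1, "area": 3000.0, "iscrowd": 1},  # crowd
    ],
}
print(select_portrait_images(toy))
```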

The Tiramisu model

Although the full name of the Tiramisu model ("The One Hundred Layers Tiramisu") suggests a gigantic model, it is actually quite economical, with only 9 million parameters. For comparison, VGG16 has over 130 million parameters.

The Tiramisu model is based on DenseNet, a recent image classification model in which all layers are interconnected. In addition, Tiramisu adds skip connections to the upsampling layers, as in Unet.

If you recall, this architecture follows the idea presented in FCN: use a classification architecture, upsample, and add skip connections for refinement.

This is what Tiramisu architecture looks like.

The DenseNet model can be seen as a natural evolution of ResNet, but instead of "remembering" each layer only until the next one, DenseNet remembers all layers throughout the model. These connections are called highway connections. They cause the number of filters to grow, and the per-layer increment is called the "growth rate". Tiramisu has a growth rate of 16, so each layer adds 16 new filters, until we reach layers of 1072 filters. You might expect 1,600 filters, since this is a 100-layer Tiramisu, but the upsampling layers drop some filters.

Densenet Model Diagram - Early filters are stacked throughout the model.
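The growth-rate mechanism can be illustrated with a toy NumPy sketch. It does not reproduce Tiramisu's layers, only how concatenation makes the channel count grow by the growth rate at every layer; the random 1x1 "convolution", the function name and the 48 input channels are illustrative assumptions.

```python
import numpy as np


def dense_block(x, n_layers, growth_rate, rng):
    """Toy DenseNet-style block on a feature map of shape (h, w, c).

    Each layer sees the concatenation of the block input and all previous
    layer outputs, and contributes `growth_rate` new channels. A random 1x1
    convolution + ReLU stands in for the real BN-ReLU-convolution layer.
    """
    features = x
    for _ in range(n_layers):
        weights = rng.standard_normal((features.shape[-1], growth_rate)) * 0.1
        new = np.maximum(features @ weights, 0.0)  # (h, w, growth_rate)
        features = np.concatenate([features, new], axis=-1)  # keep everything seen so far
    return features


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 48))
out = dense_block(x, n_layers=4, growth_rate=16, rng=rng)
print(out.shape)  # 48 + 4 * 16 = 112 channels
```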


We trained our model following the schedule described in the original paper: standard cross-entropy loss and the RMSProp optimizer with a learning rate of 1e-3 and a small decay. We split our 11 thousand images into three parts: 70% for training, 20% for validation and 10% for testing. All images below are taken from our test set.
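A reproducible 70/20/10 split can be sketched as follows; the function name and the fixed seed are our own choices, not taken from the post. Integer arithmetic is used for the split sizes because `int(11000 * 0.7)` evaluates to 7699 in floating point.

```python
import random


def split_dataset(items, train_pct=70, val_pct=20, seed=0):
    """Shuffle and split items into train/validation/test lists (the rest goes to test)."""
    items = list(items)
    random.Random(seed).shuffle(items)  # fixed seed so the split is reproducible
    n = len(items)
    n_train = n * train_pct // 100  # integer arithmetic avoids float rounding surprises
    n_val = n * val_pct // 100
    return items[:n_train], items[n_train:n_train + n_val], items[n_train + n_val:]


train, val, test = split_dataset(range(11_000))
print(len(train), len(val), len(test))
```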

To match our training schedule to the one in the original paper, we set the epoch size to 500 images. This also let us save the model periodically with every improvement in results, since we trained on much more data than the paper did (the CamVid dataset used there contains fewer than 1,000 images).

In addition, we trained our model on only two classes, background and person, whereas the original paper used 12. At first we tried training on several COCO classes, but we found that this did not lead to better results.
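Reducing COCO's many categories to the two training classes can be sketched like this. Per-instance binary masks could come from pycocotools' `COCO.annToMask`, but that call is not used here, so the sketch runs without the library; the function name is hypothetical, and category 18 is COCO's "dog".

```python
import numpy as np

PERSON_CATEGORY_ID = 1  # COCO's "person" category


def two_class_target(instance_masks, category_ids, shape):
    """Collapse per-instance masks into a background (0) / person (1) target.

    Every non-person category (animals, bats, rackets, ...) ends up as background.
    """
    target = np.zeros(shape, dtype=np.uint8)
    for mask, cat in zip(instance_masks, category_ids):
        if cat == PERSON_CATEGORY_ID:
            target |= (np.asarray(mask) > 0).astype(np.uint8)
    return target


person = np.zeros((4, 4), dtype=np.uint8)
person[1:3, 0:2] = 1
dog = np.zeros((4, 4), dtype=np.uint8)
dog[2:4, 2:4] = 1
y = two_class_target([person, dog], [1, 18], (4, 4))
print(y.sum())  # 4 person pixels; the dog is background
```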

Data problems

Some flaws in the dataset lowered our scores:

  • Animals. Our model sometimes segmented animals, which of course leads to a low IoU (intersection over union). Adding animals to the main class or as a separate class would probably affect our results.
  • Body parts. Since we filtered our dataset programmatically, we could not tell whether a "person" annotation was really a whole person rather than a body part such as an arm or a leg. Such images were not of interest to us, but they still turned up here and there.

    Animal, body part, portable object.
  • Handheld objects. Many images in the dataset are sports-related. Baseball bats, tennis rackets and snowboards were everywhere. Our model was somehow "confused" and did not know how to segment them. As with animals, we think that adding them to the main class (or as a separate class) would help improve the model.

    Sports image with an object.
  • Coarse ground truth. The COCO dataset was annotated not pixel by pixel but with polygons. Sometimes this is enough, but in some cases the ground truth is too "coarse", which may prevent the model from learning fine details.

    The image and its (very) coarse ground truth.


Our results were satisfactory, though not perfect: we reached an IoU of 84.6 on our test set, while the current state of the art is 85. However, this number varies with the dataset and class. Some classes are inherently easier to segment, such as houses or roads, where most models easily reach 90 IoU. Harder classes are trees and people, on which most models reach about 60 IoU. So we helped our network by focusing it on a single class and a limited range of photo types.
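IoU figures like these are usually quoted as percentages (84.6 corresponds to 0.846). A minimal NumPy version for one pair of binary masks might look like this; the post does not say how IoU was averaged over the test set (per image or over all pixels), so that part is left out, and the masks below are made up.

```python
import numpy as np


def iou(pred, target, eps=1e-7):
    """Intersection over union of two binary masks (1 = person, 0 = background)."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return (intersection + eps) / (union + eps)  # eps: two empty masks give 1.0


pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 0, 0], [0, 1, 1]])
print(round(100 * iou(pred, target), 1))  # 2 shared pixels / 4 in the union -> 50.0
```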

We still don't feel our work is as "release-ready" as we would like, but we think it is a good time to stop and discuss our results, since about 50% of photos already come out well.

Here are some good examples to help you get the feel of the app:

Image, ground truth, our result (from our test set).

Debugging and Logging

Debugging is a very important part of training neural networks. When we started, it was very tempting to dive straight in: take the data and the network, start training and see what happens. However, we found it extremely important to track every step and examine the results along the way.

Here are common difficulties and our solutions:

  1. Early problems. The model may fail to start learning. This can be caused by some inherent problem or by a preprocessing bug, for example, forgetting to normalize some of the data. Either way, simply visualizing the results can be very helpful. Here is a good post on this topic.
  2. Debugging the network itself. Once there are no serious problems, training proceeds with predefined losses and metrics. In segmentation, the main metric is IoU (intersection over union). It took us several sessions to start using IoU as the main metric for our models (rather than cross-entropy loss). Another useful practice was displaying our model's predictions at every epoch. Here is a good article on debugging machine learning models. Note that IoU is not a standard metric or loss in Keras, but you can easily find implementations online, for example, here. We also used this gist for plotting losses and some predictions at every epoch.
  3. Machine learning version control. Training a model involves many parameters, some of them quite tricky. We have to admit that we still have not found the ideal method, other than diligently recording all our configurations (and automatically saving the best models with a Keras callback, see below).
  4. Debugging tool. After doing all of the above, we could analyze our work at every step, but not without effort. So the most important step was to combine the steps above in a Jupyter notebook, which let us easily load every model and every image and quickly examine the results. This way we could see the differences between models and spot pitfalls and other problems.

Here are examples of improvements to our model achieved through parameter tuning and additional training:

Keras makes it easy to save the model with the best IoU result through its very handy callbacks.
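The callback configuration itself is not shown in the post. In Keras this is typically done with `keras.callbacks.ModelCheckpoint(filepath, monitor=..., save_best_only=True, mode="max")`, where the monitored name (for example `"val_iou"`) depends on what your IoU metric is called and is only a placeholder here. To show the logic without requiring Keras, here is a plain-Python sketch of what "save only the best" does; the class name and the validation values are made up.

```python
class BestCheckpoint:
    """Call `save_fn` only when the monitored metric improves (like save_best_only=True)."""

    def __init__(self, save_fn, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.save_fn = save_fn
        self.mode = mode
        self.best = None

    def _improved(self, value):
        if self.best is None:
            return True
        return value > self.best if self.mode == "max" else value < self.best

    def on_epoch_end(self, epoch, value):
        """Return True if this epoch's model was saved."""
        if self._improved(value):
            self.best = value
            self.save_fn(epoch, value)
            return True
        return False


saved = []
ckpt = BestCheckpoint(lambda epoch, value: saved.append(epoch), mode="max")
for epoch, val_iou in enumerate([0.61, 0.70, 0.68, 0.74, 0.74]):  # made-up validation IoU
    ckpt.on_epoch_end(epoch, val_iou)
print(saved, ckpt.best)  # [0, 1, 3] 0.74
```

IoU should be maximized, hence `mode="max"`; monitoring a loss would use `mode="min"`.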
Beyond the usual debugging of code errors, we noticed that the model's errors are "predictable": for example, "cutting off" body parts that do not look like part of the body, "holes" in large segments, overextended body parts, poor lighting, poor image quality and lots of fine detail. We worked around some of these errors by adding specific images from other datasets; for others, we have not found a solution yet. To improve results in the next version of the model, we plan to use augmentation targeted at images that are "hard" for our model.

We mentioned this above (in the section on dataset problems), but now let's look at some of the difficulties in more detail:

  1. Clothes. Very dark or very light clothing is sometimes interpreted as background.
  2. "Holes". Results that were otherwise good sometimes had holes in them.

    Clothes and holes.
  3. Lighting. Poor lighting and darkness are common in photos, but not in the COCO dataset. Such pictures are generally hard for models, and our model was not prepared for them. This could be addressed by adding more data, as well as by data augmentation. In the meantime, better not try our app at night :)
    An example of poor lighting.
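Lighting-oriented augmentation of the kind mentioned above could look like the following NumPy sketch. It is an assumption, not the authors' pipeline: the gamma and gain ranges are illustrative guesses, and the function name is hypothetical. The key design point is that photometric changes touch only the image, while geometric changes (here a horizontal flip) must be applied to the image and the mask together.

```python
import numpy as np


def augment(image, mask, rng):
    """Randomly change exposure, and flip image and mask together.

    image: float array in [0, 1], shape (h, w, 3); mask: label array, shape (h, w).
    The ranges below are illustrative guesses, not tuned values.
    """
    gamma = rng.uniform(0.7, 2.5)  # gamma > 1 darkens mid-tones, < 1 brightens them
    gain = rng.uniform(0.5, 1.1)  # global exposure change
    out = np.clip(gain * image ** gamma, 0.0, 1.0)  # photometric: image only
    if rng.random() < 0.5:  # geometric: must be applied to both
        out = out[:, ::-1]
        mask = mask[:, ::-1]
    return out, mask


rng = np.random.default_rng(42)
img = np.tile(np.linspace(0.0, 1.0, 8)[None, :, None], (4, 1, 3))  # left-to-right gradient
lbl = np.tile(np.arange(8), (4, 1))
aug_img, aug_lbl = augment(img, lbl, rng)
print(aug_img.shape, aug_lbl.shape)
```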

Options for Further Improvement

Further training

Our results were obtained after about 300 epochs (evaluated on our test data); after that, the model started to overfit. We reached these results very close to release, so we did not get a chance to apply standard data augmentation.

We trained the model on images resized to 224x224. Further training with more data and larger images should also improve the results (the original COCO images are about 600x1000).

CRF and other improvements

At some stages we noticed that our results were a bit "noisy" around the edges. One approach that can handle this is a CRF (conditional random field). In this post, the author gives a simplified example of using a CRF.

However, it was of little use to us, perhaps because a CRF usually helps most when the results are coarser.


Even with our current results, the segmentation is not perfect. Hair, delicate clothing, tree branches and other fine objects will never be segmented perfectly, if only because the ground-truth segmentation does not contain these nuances. The task of separating such fine details is called matting, and it brings its own difficulties. Here is an example of state-of-the-art matting, published earlier this year at an NVIDIA conference.

Matting Example - Inputs include trimap.

Matting differs from other image processing tasks in that its input includes not only the image but also a trimap, an outline of the object's edges that marks sure foreground, sure background and an uncertain region between them, which makes matting a "semi-supervised" problem.

We experimented a bit with matting, using our segmentation as a trimap, but did not achieve significant results.
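The post does not say how the segmentation was turned into a trimap. A common simple approach, shown here only as a hedged sketch, is to shrink the mask for "sure foreground", grow it for "possibly foreground", and mark the band in between as unknown. `scipy.ndimage.binary_erosion` / `binary_dilation` do the same morphology; plain NumPy is used to keep the sketch self-contained, and the band width and function names are hypothetical.

```python
import numpy as np


def _dilate(mask, iterations):
    """Binary dilation with a 3x3 square; pixels outside the image count as False."""
    h, w = mask.shape
    for _ in range(iterations):
        padded = np.pad(mask, 1, mode="constant", constant_values=False)
        grown = np.zeros_like(mask)
        for dy in range(3):
            for dx in range(3):
                grown |= padded[dy:dy + h, dx:dx + w]
        mask = grown
    return mask


def mask_to_trimap(mask, band=10):
    """Trimap from a binary mask: 255 sure foreground, 0 sure background, 128 unknown."""
    fg = np.asarray(mask).astype(bool)
    maybe_fg = _dilate(fg, band)  # grow the mask outwards by `band` pixels
    # Erosion = dilating the background; foreground touching the frame edge
    # is not eroded from that side, since outside pixels do not count as background.
    sure_fg = ~_dilate(~fg, band)
    trimap = np.zeros(fg.shape, dtype=np.uint8)
    trimap[maybe_fg] = 128
    trimap[sure_fg] = 255
    return trimap


seg = np.zeros((50, 50), dtype=np.uint8)
seg[10:40, 10:40] = 1  # a square "person"
tri = mask_to_trimap(seg, band=3)
print(tri[25, 25], tri[10, 25], tri[0, 0])
```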

Another problem was the lack of a suitable dataset for training.


As stated at the beginning, our goal was to create a meaningful product with deep learning. As you can see in Alon's posts, deployment is becoming easier and faster. Model training, on the other hand, is trickier: training, especially when it is done overnight, requires careful planning, debugging and recording of results.

It is not easy to balance research and trying new things against routine training and improvement. With deep learning, there is always a feeling that a better model, or exactly the model we need, is just around the corner, and that one more Google search or one more paper will lead us to it. But in practice, our actual improvements came from "squeezing" more and more out of our original model. And we still feel there is much more to squeeze.

We had a lot of fun doing this work, which a few months ago seemed like science fiction.