How To Fine-Tune Segment Anything

Alexandre Bonnet
April 13, 2023

Computer vision is having its ChatGPT moment with the release of the Segment Anything Model (SAM) by Meta last week. Trained on over 1 billion segmentation masks across 11 million images, SAM is a foundation model for predictive AI use cases rather than generative AI. While it has shown remarkable flexibility in segmenting a wide range of image modalities and problem spaces, it was released without “fine-tuning” functionality.

This tutorial will outline some of the key steps to fine-tune SAM using the mask decoder, particularly describing which functions from SAM to use to pre/post-process the data so that it's in good shape for fine-tuning.


What is the Segment Anything Model (SAM)?

The Segment Anything Model (SAM) is a segmentation model developed by Meta AI. It is considered the first foundation model for computer vision. SAM was trained on a huge corpus of data containing millions of images and over a billion masks, making it extremely powerful. As its name suggests, SAM is able to produce accurate segmentation masks for a wide variety of images. SAM’s design allows it to take human prompts into account, making it particularly powerful for human-in-the-loop annotation. These prompts can be multi-modal: they can be points on the area to be segmented, a bounding box around the object to be segmented, or a text prompt about what should be segmented (text prompts are explored in the paper but are not part of the publicly released model).

The model is structured into 3 components: an image encoder, a prompt encoder, and a mask decoder.

Image displaying the foundation model architecture for the Segment Anything (SA) model


The image encoder generates an embedding for the image being segmented, whilst the prompt encoder generates an embedding for the prompts. The image encoder is a particularly large component of the model. This is in contrast to the lightweight mask decoder, which predicts segmentation masks based on the embeddings. Meta AI has made the weights and biases of the model trained on the Segment Anything 1 Billion Mask (SA-1B) dataset available as a model checkpoint.

Learn more about how Segment Anything works in our explainer blog post Segment Anything Model (SAM) Explained.

What is Model Fine-Tuning?

Publicly available state-of-the-art models have a custom architecture and are typically supplied with pre-trained model weights. If these architectures were supplied without weights then the models would need to be trained from scratch by the users, who would need to use massive datasets to obtain state-of-the-art performance.

Model fine-tuning is the process of taking a pre-trained model (architecture+weights) and showing it data for a particular use case. This will typically be data that the model hasn’t seen before, or that is underrepresented in its original training dataset.

The difference between fine-tuning the model and starting from scratch is the starting value of the weights and biases. If we were training from scratch, these would be randomly initialized according to some strategy. In such a starting configuration, the model would ‘know nothing’ of the task at hand and perform poorly. By using pre-existing weights and biases as a starting point we can ‘fine tune’ the weights and biases so that our model works better on our custom dataset. For example, the information learned to recognize cats (edge detection, counting paws) will be useful for recognizing dogs.
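The effect of the starting point can be illustrated with a toy example that is far removed from SAM: a two-parameter linear model trained by plain gradient descent. The data, learning rate, step count, and "pre-trained" values below are all made up for illustration. The only point is that, for the same small training budget, starting near a good solution ends with a lower loss than starting from an uninformed initialization.

```python
# Toy illustration (not SAM): fine-tuning vs. training from scratch.
# All numbers here are invented for the example.

def mse(w, b, xs, ys):
    """Mean squared error of the linear model y = w * x + b."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def train(w, b, xs, ys, steps=5, lr=0.05):
    """Run a few steps of plain gradient descent and return the new (w, b)."""
    n = len(xs)
    for _ in range(steps):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
        w, b = w - lr * grad_w, b - lr * grad_b
    return w, b

# "Custom dataset": points on the line y = 2.2 * x + 1.5.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [2.2 * x + 1.5 for x in xs]

# "Pre-trained" weights come from a related task (y = 2 * x + 1).
finetuned = train(2.0, 1.0, xs, ys)
# Training from scratch starts from an uninformed initialization.
scratch = train(0.0, 0.0, xs, ys)

loss_finetuned = mse(*finetuned, xs, ys)
loss_scratch = mse(*scratch, xs, ys)
print(f"fine-tuned: {loss_finetuned:.4f}, from scratch: {loss_scratch:.4f}")
```

With the same five update steps, the run that starts from the related task's weights finishes much closer to the target line.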

Why Would I Fine-Tune a Model?

The purpose of fine-tuning a model is to obtain higher performance on data that the pre-trained model has not seen before. For example, an image segmentation model trained on a broad corpus of data gathered from phone cameras will have mostly seen images from a horizontal perspective.

If we tried to use this model for satellite imagery taken from a vertical perspective, it may not perform as well. If we were trying to segment rooftops, the model may not yield the best results. The pre-training is useful because the model will have learned how to segment objects in general, so we want to take advantage of this starting point to build a model that can accurately segment rooftops. Furthermore, it is likely that our custom dataset would not have millions of examples, so we want to fine-tune instead of training the model from scratch.

Fine-tuning is desirable because it lets us obtain better performance on our specific use case without incurring the computational cost of training a model from scratch.

How to Fine-Tune Segment Anything Model [With Code]

Background & Architecture

We gave an overview of the SAM architecture in the introduction section. The image encoder has a complex architecture with many parameters. In order to fine-tune the model, it makes sense for us to focus on the mask decoder which is lightweight and therefore easier, faster, and more memory efficient to fine-tune.

In order to fine-tune SAM, we need to extract the underlying pieces of its architecture (image and prompt encoders, mask decoder). We cannot use SamPredictor.predict for two reasons:

  • We want to fine-tune only the mask decoder
  • This function calls SamPredictor.predict_torch, which has the @torch.no_grad() decorator and therefore prevents us from computing gradients

Thus, we need to examine the SamPredictor.predict function and call the appropriate functions with gradient calculation enabled on the part we want to fine-tune (the mask decoder). Doing this is also a good way to learn more about how SAM works.

Creating a Custom Dataset

We need three things to fine-tune our model:

  • Images on which to draw segmentations
  • Segmentation ground truth masks
  • Prompts to feed into the model

We chose the stamp verification dataset since it has data that SAM may not have seen in its training (i.e., stamps on documents). We can verify that it performs well, but not perfectly, on this dataset by running inference with the pre-trained weights. The ground truth masks are also extremely precise, which will allow us to calculate accurate losses. Finally, this dataset contains bounding boxes around the segmentation masks, which we can use as prompts to SAM. These bounding boxes align well with the workflow that a human annotator would go through when looking to generate segmentations. An example image is shown below.

Image of stamp with bounding box

Input Data Preprocessing

We need to preprocess the scans from NumPy arrays to PyTorch tensors. To do this, we can follow what happens inside SamPredictor.set_image and SamPredictor.set_torch_image, which preprocess the image. First, we can use utils.transforms.ResizeLongestSide to resize the image, as this is the transform used inside the predictor. We can then convert the image to a PyTorch tensor and use the SAM preprocess method, which normalizes the pixel values and pads the image to a square, to finish preprocessing.
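To make the resizing step concrete, the sketch below reproduces the arithmetic in plain Python: scale the image so that its longest side matches the encoder input size, then pad on the bottom and right to reach a square. The 1024-pixel target is SAM's default input size. The helper names and the example image size are hypothetical, and the rounding is our reading of how ResizeLongestSide behaves rather than a copy of the library code.

```python
def preprocess_shape(old_h, old_w, long_side=1024):
    """Return (new_h, new_w) so that the longest side equals long_side."""
    scale = long_side / max(old_h, old_w)
    # Round to the nearest integer pixel count.
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)

def padding_to_square(new_h, new_w, long_side=1024):
    """Return (pad_bottom, pad_right) needed to reach long_side x long_side."""
    return long_side - new_h, long_side - new_w

# Hypothetical scanned page, 1500 pixels tall and 1000 pixels wide.
new_h, new_w = preprocess_shape(1500, 1000)
pad_bottom, pad_right = padding_to_square(new_h, new_w)
print((new_h, new_w), (pad_bottom, pad_right))
```

The resized (height, width) pair is also what the later upscaling step needs in order to crop the padding away again, so it is worth keeping alongside the original image size.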

Training Setup

We download the model checkpoint for the vit_b model and load it:

import torch
from segment_anything import sam_model_registry

sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')

We can set up an Adam optimizer with defaults and specify that the parameters to tune are those of the mask decoder:

optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters()) 

At the same time, we can set up our loss function, for example Mean Squared Error:

loss_fn = torch.nn.MSELoss()

Training Loop

In the main training loop, we will be iterating through our data items, generating masks, and comparing them to our ground truth masks so that we can optimize the model parameters based on the loss function.

In this example, we used a GPU for training since it is much faster than using a CPU. It is important to use .to(device) on the appropriate tensors to make sure that we don’t have certain tensors on the CPU and others on the GPU.

We embed the images inside the torch.no_grad() context manager. We are not looking to fine-tune the image encoder, and tracking gradients through such a large component would otherwise cause memory issues.

with torch.no_grad():
    image_embedding = sam_model.image_encoder(input_image)

We can also generate the prompt embeddings within the no_grad context manager. We use our bounding box coordinates, converted to pytorch tensors.

with torch.no_grad():
    sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
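Because the image was resized before being encoded, a box prompt given in original-image pixels has to be mapped into the resized frame before it reaches the prompt encoder. Below is a minimal sketch of that mapping, assuming boxes in (x_min, y_min, x_max, y_max) order. The function name and sample values are hypothetical; in practice, the apply_boxes method of the ResizeLongestSide transform covers this step.

```python
def scale_box(box, original_size, long_side=1024):
    """Map an (x_min, y_min, x_max, y_max) box into the resized image frame."""
    old_h, old_w = original_size
    scale = long_side / max(old_h, old_w)
    new_h, new_w = int(old_h * scale + 0.5), int(old_w * scale + 0.5)
    x_min, y_min, x_max, y_max = box
    # x coordinates scale with the width ratio, y coordinates with the height ratio.
    return (
        x_min * new_w / old_w,
        y_min * new_h / old_h,
        x_max * new_w / old_w,
        y_max * new_h / old_h,
    )

# Hypothetical stamp box on a scan that is 2000 pixels tall and 1000 pixels wide.
box_resized = scale_box((100, 400, 500, 800), original_size=(2000, 1000))
print(box_resized)
```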

Finally, we can generate the masks. Note that here we are in single mask generation mode (in contrast to the 3 masks that are normally output).

low_res_masks, iou_predictions = sam_model.mask_decoder(
  image_embeddings=image_embedding,
  image_pe=sam_model.prompt_encoder.get_dense_pe(),
  sparse_prompt_embeddings=sparse_embeddings,
  dense_prompt_embeddings=dense_embeddings,
  multimask_output=False,
)

The final step here is to upscale the low-resolution masks back to the original image size, which we can do with Sam.postprocess_masks. We also want to generate binary masks from the predicted masks so that we can compare them to our ground truths. We use torch functional operations for this so that the computation stays within the autograd graph.

from torch.nn.functional import threshold, normalize

# input_size: (height, width) after resizing; original_image_size: (height, width) of the source image
upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)

binary_mask = normalize(threshold(upscaled_masks, 0.0, 0)).to(device)

Finally, we can calculate the loss and run an optimization step:

loss = loss_fn(binary_mask, gt_binary_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
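One caveat worth checking in your own setup: a hard threshold is piecewise constant, so a binarized mask may carry no useful gradient back to the decoder. The sketch below uses a made-up 2x2 logit map and ground truth to compare the gradient obtained through threshold and normalize with the gradient obtained when the loss is applied to sigmoid probabilities instead. Swapping in the sigmoid is our suggestion rather than part of the walkthrough above, and the helper name is hypothetical.

```python
import torch
from torch.nn.functional import mse_loss, normalize, threshold

# Made-up ground truth mask, shape (batch, channel, height, width).
gt_mask = torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]])

def grad_of_loss(to_mask):
    """Return d(loss)/d(logits) when the logits are mapped to a mask by to_mask."""
    # Made-up mask logits standing in for the upscaled decoder output.
    logits = torch.tensor([[[[2.0, -1.0], [0.5, -4.0]]]], requires_grad=True)
    loss = mse_loss(to_mask(logits), gt_mask)
    loss.backward()
    return logits.grad

# Hard binarization: the output is constant in a neighbourhood of each logit.
hard_grad = grad_of_loss(lambda x: normalize(threshold(x, 0.0, 0)))
# Soft alternative: sigmoid probabilities change smoothly with the logits.
soft_grad = grad_of_loss(torch.sigmoid)

print("largest gradient magnitude, hard mask:", hard_grad.abs().max().item())
print("largest gradient magnitude, soft mask:", soft_grad.abs().max().item())
```

If the first number is zero in your environment as well, the optimizer step has nothing to work with, and a loss computed on the sigmoid output (or a logit-based loss such as torch.nn.BCEWithLogitsLoss) is one way to restore a training signal.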

By repeating this over a number of epochs and batches we can fine-tune the SAM decoder.
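Putting the pieces together, the overall loop has the shape sketched below. To keep it runnable without the SAM checkpoint, tiny convolution layers stand in for the image encoder and the mask decoder, and the images and masks are random tensors. The stand-in modules, tensor sizes, and epoch count are all hypothetical. We also assume torch.nn.BCEWithLogitsLoss on the raw mask logits here, which is a swap from the MSE loss above, chosen because it is differentiable with respect to the logits.

```python
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-ins for SAM's image encoder (frozen) and mask decoder (tuned).
encoder = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to(device)
decoder = torch.nn.Conv2d(8, 1, kernel_size=1).to(device)

# Random "images" and "ground truth masks" in place of a real dataset.
images = torch.rand(4, 3, 16, 16, device=device)
gt_masks = (torch.rand(4, 1, 16, 16, device=device) > 0.5).float()

# Only the decoder parameters are handed to the optimizer.
optimizer = torch.optim.Adam(decoder.parameters())
loss_fn = torch.nn.BCEWithLogitsLoss()

encoder_before = encoder.weight.detach().clone()
decoder_before = decoder.weight.detach().clone()

losses = []
for epoch in range(3):
    for image, gt_mask in zip(images, gt_masks):
        # The encoder runs without gradient tracking, as in the steps above.
        with torch.no_grad():
            embedding = encoder(image[None])
        mask_logits = decoder(embedding)
        loss = loss_fn(mask_logits, gt_mask[None])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Comparing the stored copies with the weights after training is a quick sanity check that only the decoder was updated.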

Saving Checkpoints and Starting a Model from it

Once we are done with training and satisfied with the performance uplift, we can save the state dict of the tuned model using:

torch.save(sam_model.state_dict(), PATH)

We can then load this state dict when we want to perform inference on data that is similar to the data we used to fine-tune the model.
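A minimal sketch of the save-and-reload round trip, using a small stand-in module so that it runs without the SAM checkpoint. The module, file name, and temporary directory are hypothetical; with SAM itself, the same pattern would apply to the model object built from the registry.

```python
import os
import tempfile

import torch

# Hypothetical stand-in for the fine-tuned model.
tuned_model = torch.nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tuned_checkpoint.pth")
    # Save only the parameters (the state dict), not the whole Python object.
    torch.save(tuned_model.state_dict(), path)

    # Later: rebuild the same architecture, then load the saved parameters into it.
    restored_model = torch.nn.Linear(4, 2)
    restored_model.load_state_dict(torch.load(path, map_location="cpu"))

restored_model.eval()  # switch to inference mode before predicting
weights_match = torch.equal(tuned_model.weight, restored_model.weight)
print(weights_match)
```

Loading with map_location="cpu" lets a checkpoint saved on a GPU machine be opened on a machine without one; the model can be moved to another device afterwards.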

You can find the Colab Notebook with all the code you need to fine-tune SAM here. Keep reading if you want a fully working solution out of the box!

Fine-Tuning for Downstream Applications

While SAM does not currently offer fine-tuning out of the box, we are building a custom fine-tuner integrated with the Encord platform. As shown in this post, we fine-tune the decoder in order to achieve this. This is available as an out-of-the-box one-click procedure in the web app, where the hyperparameters are automatically set.

Image displaying training the Segment Anything Model (SAM) in the Encord platform

Original vanilla SAM mask:

Image of the original vanilla SAM mask

Mask generated by fine-tuned version of the model:

Image of the mask generated by the fine tuned version of the model

We can see that this mask is tighter than the original mask. This was the result of fine-tuning on a small subset of images from the stamp verification dataset, and then running the tuned model on a previously unseen example. With further training and more examples, we could obtain even better results.

Conclusion

That's all, folks!

You have now learned how to fine-tune the Segment Anything Model (SAM). If you're looking to fine-tune SAM out of the box, you might also be interested to learn that we have recently released the Segment Anything Model in Encord, allowing you to fine-tune the model without writing any code.

Image of a frog segmented by the Segment Anything Model (SAM) inside the Encord platform

Written by Alexandre Bonnet
Alexandre Bonnet is a Machine Learning Solutions Engineer at Encord. He holds a Master's degree in Theoretical Physics from Imperial College and has previous industry experience in data engineering and data science. Alexandre has also been a member of Entrepreneur First's LD17 cohort as th...