diff --git a/README.md b/README.md
index 95d24ee2..a695dbd5 100644
--- a/README.md
+++ b/README.md
@@ -394,6 +394,33 @@ Note: In the official github repo the s0 variant has additional num_conv_branche
 </div>
 </details>
 
+<details>
+<summary style="margin-left: 25px;">SAM</summary>
+<div style="margin-left: 25px;">
+
+| Encoder   | Weights  | Params, M |
+|-----------|:--------:|:---------:|
+| sam-vit_b |  sa-1b   |    91M    |
+| sam-vit_l |  sa-1b   |   308M    |
+| sam-vit_h |  sa-1b   |   636M    |
+
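+Note: With the default `patch_size=16`, the SAM ViT encoders require `encoder_depth=4`. A minimal usage sketch:
+
+```python
+import segmentation_models_pytorch as smp
+
+# SAM image encoders expect 1024x1024 inputs by default (SAM's native img_size)
+model = smp.Unet(
+    "sam-vit_b",
+    encoder_weights="sa-1b",
+    encoder_depth=4,
+    decoder_channels=[64, 32, 16, 8],
+)
+```
+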
+</div>
+</details>
+
 
 \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
 
diff --git a/docs/encoders.rst b/docs/encoders.rst
index d64607b8..55946e6e 100644
--- a/docs/encoders.rst
+++ b/docs/encoders.rst
@@ -361,3 +361,16 @@ MobileOne
 +-----------------+----------+------------+
 | mobileone\_s4   | imagenet | 13.6M      |
 +-----------------+----------+------------+
+
+SAM
+~~~~~~~~~~~~~~~~~~~~~
+
++-----------------+----------+------------+
+| Encoder         | Weights  | Params, M  |
++=================+==========+============+
+| sam-vit_b       | sa-1b    | 91M        |
++-----------------+----------+------------+
+| sam-vit_l       | sa-1b    | 308M       |
++-----------------+----------+------------+
+| sam-vit_h       | sa-1b    | 636M       |
++-----------------+----------+------------+
diff --git a/docs/models.rst b/docs/models.rst
index 47de61ee..a5ab52c1 100644
--- a/docs/models.rst
+++ b/docs/models.rst
@@ -36,5 +36,3 @@ DeepLabV3
 DeepLabV3+
 ~~~~~~~~~~
 .. autoclass:: segmentation_models_pytorch.DeepLabV3Plus
-
-
diff --git a/requirements.txt b/requirements.txt
index 5f1a53ac..9e6cd5a7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ torchvision>=0.5.0
 pretrainedmodels==0.7.4
 efficientnet-pytorch==0.7.1
 timm==0.9.7
-
+segment-anything-py==1.0
 tqdm
 pillow
 six
diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py
index 7551153f..635f44b4 100644
--- a/segmentation_models_pytorch/encoders/__init__.py
+++ b/segmentation_models_pytorch/encoders/__init__.py
@@ -4,6 +4,7 @@
 
 from .resnet import resnet_encoders
 from .dpn import dpn_encoders
+from .sam import sam_vit_encoders, SamVitEncoder
 from .vgg import vgg_encoders
 from .senet import senet_encoders
 from .densenet import densenet_encoders
@@ -46,6 +47,39 @@
 encoders.update(timm_gernet_encoders)
 encoders.update(mix_transformer_encoders)
 encoders.update(mobileone_encoders)
+encoders.update(sam_vit_encoders)
+
+
+def get_pretrained_settings(encoders: dict, encoder_name: str, weights: str) -> dict:
+    """Get pretrained settings for encoder from encoders collection.
+
+    Args:
+        encoders: collection of encoders
+        encoder_name: name of encoder in collection
+        weights: one of ``None`` (random initialization), ``imagenet`` or other pretrained settings
+
+    Returns:
+        pretrained settings for encoder
+
+    Raises:
+        KeyError: in case of wrong encoder name or pretrained settings name
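+
+    Example:
+        >>> settings = get_pretrained_settings(encoders, "resnet18", "imagenet")
+        >>> "url" in settings
+        True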
+    """
+    try:
+        settings = encoders[encoder_name]["pretrained_settings"][weights]
+    except KeyError:
+        raise KeyError(
+            "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
+                weights,
+                encoder_name,
+                list(encoders[encoder_name]["pretrained_settings"].keys()),
+            )
+        )
+    return settings
 
 
 def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
@@ -69,19 +98,11 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
 
     params = encoders[name]["params"]
     params.update(depth=depth)
+    params.update(kwargs)
     encoder = Encoder(**params)
 
     if weights is not None:
-        try:
-            settings = encoders[name]["pretrained_settings"][weights]
-        except KeyError:
-            raise KeyError(
-                "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
-                    weights,
-                    name,
-                    list(encoders[name]["pretrained_settings"].keys()),
-                )
-            )
+        settings = get_pretrained_settings(encoders, name, weights)
         encoder.load_state_dict(model_zoo.load_url(settings["url"]))
 
     encoder.set_in_channels(in_channels, pretrained=weights is not None)
diff --git a/segmentation_models_pytorch/encoders/_base.py b/segmentation_models_pytorch/encoders/_base.py
index aab838f1..fee8d177 100644
--- a/segmentation_models_pytorch/encoders/_base.py
+++ b/segmentation_models_pytorch/encoders/_base.py
@@ -1,8 +1,3 @@
-import torch
-import torch.nn as nn
-from typing import List
-from collections import OrderedDict
-
 from . import _utils as utils
 
 
diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py
new file mode 100644
index 00000000..aac722ba
--- /dev/null
+++ b/segmentation_models_pytorch/encoders/sam.py
@@ -0,0 +1,171 @@
+import math
+import warnings
+from typing import Any, List, Mapping
+
+import torch
+from torch import nn
+from segment_anything.modeling import ImageEncoderViT
+from segment_anything.modeling.common import LayerNorm2d
+
+from segmentation_models_pytorch.encoders._base import EncoderMixin
+
+
+class SamVitEncoder(EncoderMixin, ImageEncoderViT):
+    def __init__(self, **kwargs):
+        self._vit_depth = kwargs.pop("vit_depth")
+        self._encoder_depth = kwargs.get("depth", 5)
+        self._out_chans = kwargs.get("out_chans", 256)
+        self._patch_size = kwargs.get("patch_size", 16)
+        self._embed_dim = kwargs.get("embed_dim", 768)
+        # Validate before building the ViT so that an invalid vit_depth or encoder depth
+        # raises a clear ValueError instead of an opaque error from ImageEncoderViT.__init__
+        self._validate()
+        kwargs.update({"depth": self._vit_depth})
+        super().__init__(**kwargs)
+        self.intermediate_necks = nn.ModuleList(
+            [self.init_neck(self._embed_dim, out_chan) for out_chan in self.out_channels[:-1]]
+        )
+
+    @staticmethod
+    def init_neck(embed_dim: int, out_chans: int) -> nn.Module:
+        # Use similar neck as in ImageEncoderViT
+        return nn.Sequential(
+            nn.Conv2d(
+                embed_dim,
+                out_chans,
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+            nn.Conv2d(
+                out_chans,
+                out_chans,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+        )
+
+    @staticmethod
+    def neck_forward(neck: nn.Module, x: torch.Tensor, scale_factor: float = 1) -> torch.Tensor:
+        x = x.permute(0, 3, 1, 2)
+        if scale_factor != 1.0:
+            x = nn.functional.interpolate(x, scale_factor=scale_factor, mode="bilinear")
+        return neck(x)
+
+    def requires_grad_(self, requires_grad: bool = True):
+        # Keep the intermediate necks trainable
+        for param in self.parameters():
+            param.requires_grad_(requires_grad)
+        for param in self.intermediate_necks.parameters():
+            param.requires_grad_(True)
+        return self
+
+    @property
+    def output_stride(self):
+        return 32
+
+    @property
+    def out_channels(self):
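+        # e.g. [16, 32, 64, 128, 256] for the default out_chans=256 and encoder_depth=4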
+        return [self._out_chans // (2**i) for i in range(self._encoder_depth + 1)][::-1]
+
+    def _validate(self):
+        # check vit depth
+        if self._vit_depth not in [12, 24, 32]:
+            raise ValueError(f"vit_depth must be one of [12, 24, 32], got {self._vit_depth}")
+        # check output
+        scale_factor = self._get_scale_factor()
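+        # e.g. with the default patch_size=16, log2(16) = 4, so the encoder depth must
+        # also be 4 for the produced feature maps to match the expected spatial sizes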
+        if scale_factor != self._encoder_depth:
+            raise ValueError(
+                f"With patch_size={self._patch_size} and depth={self._encoder_depth}, "
+                "spatial dimensions of model output will not match input spatial dimensions. "
+                "It is recommended to set encoder depth=4 with default vit patch_size=16."
+            )
+
+    def _get_scale_factor(self) -> int:
+        """Return log2(patch_size), i.e. the number of 2x downscaling steps applied by the patch embedding."""
+        return int(math.log(self._patch_size, 2))
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        x = self.patch_embed(x)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+
+        features = []
+        skip_steps = self._vit_depth // self._encoder_depth
+        scale_factor = self._get_scale_factor()
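+        # Tap the ViT every `skip_steps` blocks (e.g. every 3rd block for vit_depth=12 and
+        # encoder_depth=4) and project each tap through its matching intermediate neck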
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i % skip_steps == 0:
+                # Double spatial dimension and halve number of channels
+                neck = self.intermediate_necks[i // skip_steps]
+                features.append(self.neck_forward(neck, x, scale_factor=2**scale_factor))
+                scale_factor -= 1
+
+        x = self.neck(x.permute(0, 3, 1, 2))
+        features.append(x)
+
+        return features
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> None:
+        # Exclude mask_decoder and prompt encoder weights
+        # and remove 'image_encoder.' prefix
+        state_dict = {
+            k.replace("image_encoder.", ""): v
+            for k, v in state_dict.items()
+            if not k.startswith("mask_decoder") and not k.startswith("prompt_encoder")
+        }
+        missing, unused = super().load_state_dict(state_dict, strict=False)
+        missing = list(filter(lambda x: not x.startswith("intermediate_necks"), missing))
+        if len(missing) + len(unused) > 0:
+            n_loaded = len(state_dict) - len(unused)
+            warnings.warn(
+                f"Only {n_loaded} out of {len(state_dict)} pretrained SAM image encoder modules were loaded. "
+                f"Missing modules: {missing}. Unused modules: {unused}."
+            )
+
+
+sam_vit_encoders = {
+    "sam-vit_h": {
+        "encoder": SamVitEncoder,
+        "pretrained_settings": {
+            "sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"},
+        },
+        "params": dict(
+            embed_dim=1280,
+            vit_depth=32,
+            num_heads=16,
+            global_attn_indexes=[7, 15, 23, 31],
+        ),
+    },
+    "sam-vit_l": {
+        "encoder": SamVitEncoder,
+        "pretrained_settings": {
+            "sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"},
+        },
+        "params": dict(
+            embed_dim=1024,
+            vit_depth=24,
+            num_heads=16,
+            global_attn_indexes=[5, 11, 17, 23],
+        ),
+    },
+    "sam-vit_b": {
+        "encoder": SamVitEncoder,
+        "pretrained_settings": {
+            "sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"},
+        },
+        "params": dict(
+            embed_dim=768,
+            vit_depth=12,
+            num_heads=12,
+            global_attn_indexes=[2, 5, 8, 11],
+        ),
+    },
+}
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_models.py b/tests/test_models.py
index c2e6d941..08e87b91 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -14,6 +14,9 @@ def get_encoders():
         "resnext101_32x16d",
         "resnext101_32x32d",
         "resnext101_32x48d",
+        "sam-vit_h",
+        "sam-vit_l",
+        "sam-vit_b",
     ]
     encoders = smp.encoders.get_encoder_names()
     encoders = [e for e in encoders if e not in exclude_encoders]
diff --git a/tests/test_sam.py b/tests/test_sam.py
new file mode 100644
index 00000000..2c377fa9
--- /dev/null
+++ b/tests/test_sam.py
@@ -0,0 +1,74 @@
+import pytest
+import torch
+
+import segmentation_models_pytorch as smp
+from segmentation_models_pytorch.encoders import get_encoder
+from tests.test_models import _test_forward_backward
+
+
+@pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"])
+@pytest.mark.parametrize("img_size", [64, 128])
+@pytest.mark.parametrize("patch_size,depth", [(8, 3), (16, 4)])
+@pytest.mark.parametrize("vit_depth", [12, 24])
+def test_sam_encoder(encoder_name, img_size, patch_size, depth, vit_depth):
+    encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth, vit_depth=vit_depth)
+    assert encoder.output_stride == 32
+    assert encoder.out_channels == [256 // (2**i) for i in range(depth + 1)][::-1]
+
+    sample = torch.ones(1, 3, img_size, img_size)
+    with torch.no_grad():
+        out = encoder(sample)
+
+    assert len(out) == depth + 1
+
+    expected_spatial_size = img_size // patch_size
+    expected_chans = 256
+    for i in range(1, len(out)):
+        assert out[-i].size() == torch.Size([1, expected_chans, expected_spatial_size, expected_spatial_size])
+        expected_spatial_size *= 2
+        expected_chans //= 2
+
+
+def test_sam_encoder_trainable():
+    encoder = get_encoder("sam-vit_b", depth=4)
+
+    encoder.requires_grad_(False)
+    for name, param in encoder.named_parameters():
+        if name.startswith("intermediate_necks"):
+            assert param.requires_grad
+        else:
+            assert not param.requires_grad
+
+    encoder.requires_grad_(True)
+    for param in encoder.parameters():
+        assert param.requires_grad
+
+
+def test_sam_encoder_validation_error():
+    with pytest.raises(ValueError):
+        get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=5, vit_depth=12)
+    with pytest.raises(ValueError):
+        get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=None)
+    with pytest.raises(ValueError):
+        get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=6)
+
+
+@pytest.mark.parametrize("model_class", [smp.Unet])
+@pytest.mark.parametrize("decoder_channels,encoder_depth", [([64, 32, 16, 8], 4), ([64, 32, 16, 8], 4)])
+def test_sam_encoder_arch(model_class, decoder_channels, encoder_depth):
+    img_size = 1024
+    model = model_class(
+        "sam-vit_b",
+        encoder_weights=None,
+        encoder_depth=encoder_depth,
+        decoder_channels=decoder_channels,
+    )
+    sample = torch.ones(1, 3, img_size, img_size)
+    _test_forward_backward(model, sample, test_shape=True)
+
+
+@pytest.mark.skip(reason="Run this test manually as it needs to download weights")
+def test_sam_encoder_weights():
+    smp.create_model(
+        "unet", encoder_name="sam-vit_b", encoder_depth=4, encoder_weights="sa-1b", decoder_channels=[64, 32, 16, 8]
+    )