AWS Open Source Blog

Deploy fast.ai-trained PyTorch model in TorchServe and host in Amazon SageMaker inference endpoint

Over the past few years, fast.ai has become one of the most cutting-edge, open source, deep learning frameworks and the go-to choice for many machine learning use cases based on PyTorch. It has not only democratized deep learning and made it approachable to general audiences, but has also become a role model for how scientific software should be engineered, especially in Python. However, deploying a fast.ai model to a production environment often involves setting up and maintaining a customized inference solution (e.g., with Flask), and handling concerns such as security, load balancing, and service orchestration yourself is time-consuming and distracting.

Recently, in partnership with Facebook, Amazon Web Services (AWS) developed TorchServe, a flexible and easy-to-use open source tool for serving PyTorch models. TorchServe removes the heavy lifting of deploying and serving PyTorch models with Kubernetes, and AWS and Facebook will maintain and continue contributing to TorchServe along with the broader PyTorch community. Many TorchServe features work out of the box and give you full flexibility in deploying trained PyTorch models at scale, so that a trained model can go into production with a few extra lines of code.

Meanwhile, an Amazon SageMaker endpoint is a fully managed service that allows users to make real-time inferences via a REST API, which saves data scientists and machine learning engineers from managing their own server instances, load balancing, fault tolerance, auto-scaling, and model monitoring, among other tasks. Amazon SageMaker endpoints also offer different instance types suitable for different tasks, including ones with GPU(s), which support industry-level machine learning inference and graphics-intensive applications while being cost-effective.

In this article, we demonstrate how to deploy a fast.ai-trained PyTorch model in TorchServe eager mode and host it in an Amazon SageMaker inference endpoint.

Getting started with a fast.ai model

In this section, we train a fast.ai model that can solve a real-world problem with performance meeting the use-case specification. As an example, we focus on a “Scene Segmentation” use case from a self-driving car.

Installation

The first step is to install the fast.ai package, which is covered in the GitHub repository, as follows:

If you’re using Anaconda then run:

conda install -c fastai -c pytorch -c anaconda fastai gh anaconda

Or, if you’re using Miniconda then run:

conda install -c fastai -c pytorch fastai

For other installation options, please refer to the fast.ai documentation.

Modeling

The following materials are based on the fast.ai course called Practical Deep Learning for Coders.

First, import fastai.vision modules and download the sample data CAMVID_TINY by doing:

from fastai.vision.all import *
path = untar_data(URLs.CAMVID_TINY)

Second, define helper functions to calculate segmentation performance and to read in the segmentation mask for each training image.

Note: Defining one-line Python lambda functions to pass to fastai is tempting; however, this will introduce issues on serialization when we want to export a fast.ai model. Therefore, we avoid using anonymous Python functions during fast.ai modeling steps.
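To see why, recall that the fast.ai export is a Python pickle, and pickle stores functions by reference to their importable name rather than by value. The following standard-library sketch shows a lambda failing to pickle; the names label_name, label_name_lambda, and can_pickle are hypothetical and not part of fast.ai.

```python
import pickle


def label_name(stem):
    """Named, module-level function: pickle records only its qualified name."""
    return f"{stem}_P.png"


# Same logic as an anonymous function. Its qualified name is "<lambda>",
# which pickle cannot look up again, so serialization is refused.
label_name_lambda = lambda stem: f"{stem}_P.png"


def can_pickle(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        return False


print("named function picklable:", can_pickle(label_name))
print("lambda picklable:", can_pickle(label_name_lambda))
```

Both callables compute the same result; only the named one can be referenced from an exported learner.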

def acc_camvid(inp, targ, void_code=0):
    targ = targ.squeeze(1)
    mask = targ != void_code
    return (inp.argmax(dim=1)[mask] == targ[mask]).float().mean()

def get_y(o, path=path):
    return path / "labels" / f"{o.stem}_P{o.suffix}"
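To illustrate what acc_camvid measures, here is a small NumPy re-implementation of the same masking logic on a hand-made 2x2 example. It is only a sketch: masked_pixel_accuracy and the toy arrays are made up for illustration, and the target here already has shape (batch, H, W), so no squeeze is needed.

```python
import numpy as np


def masked_pixel_accuracy(logits, target, void_code=0):
    """Pixel accuracy that ignores pixels labeled with void_code.

    logits: float array of shape (batch, classes, H, W)
    target: int array of shape (batch, H, W)
    """
    pred = logits.argmax(axis=1)   # predicted class per pixel
    mask = target != void_code     # keep only non-void pixels
    return float((pred[mask] == target[mask]).mean())


# Predicted classes for one 2x2 image, turned into one-hot "logits".
pred_classes = np.array([[[1, 2], [0, 1]]])
logits = np.eye(3)[pred_classes].transpose(0, 3, 1, 2)  # (1, 3, 2, 2)

# Ground truth: the bottom-left pixel is void (class 0).
target = np.array([[[1, 2], [0, 2]]])

masked = masked_pixel_accuracy(logits, target)
unmasked = float((logits.argmax(axis=1) == target).mean())
print(masked, unmasked)
```

In this toy case two of the three non-void pixels are correct, whereas counting the void pixel as well would report three of four, which is why the void class is excluded from the metric.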

Third, we set up the DataLoaders, which define the modeling path, training image paths, batch size, mask paths, mask codes, etc. In this example, we also record the image size and the number of classes from the data. In a real-world problem, these values may be known beforehand and would be defined when constructing the dataset.

dls = SegmentationDataLoaders.from_label_func(
    path,
    bs=8,
    fnames=get_image_files(path / "images"),
    label_func=get_y,
    codes=np.loadtxt(path / "codes.txt", dtype=str),
)
dls.one_batch()[0].shape[-2:], get_c(dls)
>>> (torch.Size([96, 128]), 32)

Next, we set up a U-Net learner with a Residual Neural Network (ResNet) backbone and then trigger the fast.ai training process.

learn = unet_learner(dls, resnet50, metrics=acc_camvid)
learn.fine_tune(20)
>>>
epoch	train_loss	valid_loss	acc_camvid	time
0	3.901105	2.671725	0.419333	00:04
epoch	train_loss	valid_loss	acc_camvid	time
0	1.732219	1.766196	0.589736	00:03
1	1.536345	1.550913	0.612496	00:02
2	1.416585	1.170476	0.650690	00:02
3	1.300092	1.087747	0.665566	00:02
4	1.334166	1.228493	0.649878	00:03
5	1.269190	1.047625	0.711870	00:02
6	1.243131	0.969567	0.719976	00:03
7	1.164861	0.988767	0.700076	00:03
8	1.103572	0.791861	0.787799	00:02
9	1.026181	0.721673	0.806758	00:02
10	0.949283	0.650206	0.815247	00:03
11	0.882919	0.696920	0.812805	00:03
12	0.823694	0.635109	0.824582	00:03
13	0.766428	0.631013	0.832627	00:02
14	0.715637	0.591066	0.839386	00:03
15	0.669535	0.601648	0.836554	00:03
16	0.628947	0.598065	0.840095	00:03
17	0.593876	0.578633	0.841116	00:02
18	0.563728	0.582522	0.841409	00:03
19	0.539064	0.580864	0.842272	00:02

Finally, we export the fast.ai model for use in the following sections of this tutorial.

learn.export("./fastai_unet.pkl")

For more details about the modeling process, refer to the following AWS sample: notebook/01_U-net_Modelling.ipynb.

PyTorch transfer modeling from fast.ai

In this section we build a pure PyTorch model and transfer the model weights from fast.ai. The following materials are inspired by Practical-Deep-Learning-for-Coders-2.0 by Zachary Mueller et al.

Export model weights from fast.ai

First, we restore the fast.ai learner from the export “pickle” in the last section and save its model weights with PyTorch.

from fastai.vision.all import *
import torch

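# The exported pickle refers to acc_camvid and get_y by name, so placeholders
# with the same names must exist in this namespace before load_learner is called.
def acc_camvid(*_): pass
def get_y(*_): pass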

learn = load_learner("/home/ubuntu/.fastai/data/camvid_tiny/fastai_unet.pkl")
torch.save(learn.model.state_dict(), "fasti_unet_weights.pth")

Obtaining the fast.ai prediction on a sample image is also straightforward.

“2013.04 – ‘Streetview of a small neighborhood’, with residential buildings, Amsterdam city photo by Fons Heijnsbroek, The Netherlands” by Amsterdam free photos & pictures of the Dutch city is marked under CC0 1.0. To view the terms, visit https://creativecommons.org/licenses/cc0/1.0/

Photo of a neighborhood.

image_path = "street_view_of_a_small_neighborhood.png"
pred_fastai = learn.predict(image_path)
pred_fastai[0].numpy()
>>>
array([[26, 26, 26, ...,  4,  4,  4],
       [26, 26, 26, ...,  4,  4,  4],
       [26, 26, 26, ...,  4,  4,  4],
       ...,
       [17, 17, 17, ..., 30, 30, 30],
       [17, 17, 17, ..., 30, 30, 30],
       [17, 17, 17, ..., 30, 30, 30]])

PyTorch model from fast.ai source code

Next, we need to define the model in pure PyTorch. In a Jupyter notebook, you can investigate the fast.ai source code by adding ?? in front of a function name. Here we look into unet_learner and DynamicUnet by doing:

>> ??unet_learner
>> ??DynamicUnet

Each of these commands will pop up a window at the bottom of the browser:

Screenshot of the pop up window that appears after running the commands.
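Outside Jupyter, the standard library's inspect.getsource returns the source text of pure-Python objects in a similar way. The sketch below uses json.dumps as a stand-in so that it runs without fast.ai installed; with fast.ai available, one would presumably pass unet_learner or DynamicUnet instead.

```python
import inspect
import json

# Fetch the source code of a pure-Python function as a string.
# json.dumps is only a stand-in for the object being inspected.
source = inspect.getsource(json.dumps)

# Show the first few lines, e.g. the signature and start of the docstring.
print("\n".join(source.splitlines()[:5]))
```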

After investigating, the PyTorch model can be defined as:

from fastai.vision.all import *
from fastai.vision.learner import _default_meta
from fastai.vision.models.unet import _get_sz_change_idxs, UnetBlock, ResizeToOrig


class DynamicUnetDIY(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        arch=resnet50,
        n_classes=32,
        img_size=(96, 128),
        blur=False,
        blur_final=True,
        y_range=None,
        last_cross=True,
        bottle=False,
        init=nn.init.kaiming_normal_,
        norm_type=None,
        self_attention=None,
        act_cls=defaults.activation,
        n_in=3,
        cut=None,
        **kwargs
    ):
        meta = model_meta.get(arch, _default_meta)
        encoder = create_body(
            arch, n_in, pretrained=False, cut=ifnone(cut, meta["cut"])
        )
        imsize = img_size

        sizes = model_sizes(encoder, size=imsize)
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sizes[-1][1]
        middle_conv = nn.Sequential(
            ConvLayer(ni, ni * 2, act_cls=act_cls, norm_type=norm_type, **kwargs),
            ConvLayer(ni * 2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sz_chg_idxs):
            not_final = i != len(sz_chg_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sz_chg_idxs) - 3)
            unet_block = UnetBlock(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,
                self_attention=sa,
                act_cls=act_cls,
                init=init,
                norm_type=norm_type,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sizes[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
        layers.append(ResizeToOrig())
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(
                ResBlock(
                    1,
                    ni,
                    ni // 2 if bottle else ni,
                    act_cls=act_cls,
                    norm_type=norm_type,
                    **kwargs
                )
            )
        layers += [
            ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, **kwargs)
        ]
        apply_init(nn.Sequential(layers[3], layers[-2]), init)
        # apply_init(nn.Sequential(layers[2]), init)
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"):
            self.sfs.remove()

Also, we check the inheritance hierarchy of the fast.ai-defined class SequentialEx by:

SequentialEx.mro()
>>> [fastai.layers.SequentialEx,
 fastai.torch_core.Module,
 torch.nn.modules.module.Module,
 object]

Here we can see that SequentialEx inherits from torch.nn.modules.module.Module (that is, torch.nn.Module); therefore, DynamicUnetDIY is a PyTorch model.

Note: The arch, n_classes, img_size, and other parameters must be consistent with the training process. If other parameters were customized during training, they must be reflected here as well. Also, in create_body we set pretrained=False because we are transferring the weights from fast.ai, so there is no need to download pretrained weights again.

Weights transfer

Now we can initialize the PyTorch model, load the saved model weights, and transfer the weights to the PyTorch model.

model_torch_rep = DynamicUnetDIY()
state = torch.load("fasti_unet_weights.pth")
model_torch_rep.load_state_dict(state)
model_torch_rep.eval();
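The same save/load pattern can be tried on a toy network to see what a successful transfer looks like. The sketch below is an illustration only: make_model is a hypothetical two-layer stand-in for DynamicUnetDIY, and an in-memory buffer replaces the .pth file.

```python
import io

import torch
from torch import nn


def make_model():
    """Hypothetical tiny network standing in for the real U-Net."""
    return nn.Sequential(
        nn.Conv2d(3, 4, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(4, 2, kernel_size=1),
    )


torch.manual_seed(0)
source_model = make_model().eval()   # plays the role of learn.model
target_model = make_model().eval()   # freshly initialized, different weights

# Save only the weights, then read them back (here via an in-memory buffer).
buffer = io.BytesIO()
torch.save(source_model.state_dict(), buffer)
buffer.seek(0)
state = torch.load(buffer, map_location="cpu")

# strict=True (the default) raises if any key is missing or unexpected.
result = target_model.load_state_dict(state)

x = torch.rand(1, 3, 8, 8)
with torch.no_grad():
    same = torch.equal(source_model(x), target_model(x))
print(result, same)
```

Because load_state_dict matches parameters by name, the re-created architecture has to produce exactly the same parameter names and shapes as the trained model.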

If we take one sample image, transform it, and pass it to the model_torch_rep, we will get a prediction result identical to fast.ai’s.

from torchvision import transforms
from PIL import Image
import numpy as np

image_path = "street_view_of_a_small_neighborhood.png"

image = Image.open(image_path).convert("RGB")
image_tfm = transforms.Compose(
    [
        transforms.Resize((96, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

x = image_tfm(image).unsqueeze_(0)

# inference on CPU
raw_out = model_torch_rep(x)
raw_out.shape
>>> torch.Size([1, 32, 96, 128])

pred_res = raw_out[0].argmax(dim=0).numpy().astype(np.uint8)
pred_res
>>>
array([[26, 26, 26, ...,  4,  4,  4],
       [26, 26, 26, ...,  4,  4,  4],
       [26, 26, 26, ...,  4,  4,  4],
       ...,
       [17, 17, 17, ..., 30, 30, 30],
       [17, 17, 17, ..., 30, 30, 30],
       [17, 17, 17, ..., 30, 30, 30]], dtype=uint8)

np.all(pred_fastai[0].numpy() == pred_res)
>>> True

Here we can see the difference: The fast.ai model fastai_unet.pkl packages all the steps including the data transformation, image dimension alignment, etc. However, fasti_unet_weights.pth has only the pure weights, and we have to manually redefine the data transformation procedures, among others, and make sure they are consistent with the training step.

Note: In image_tfm, we need to make sure the image size and normalization statistics are consistent with the training step. In our example here, the size is 96x128, and normalization is by default from ImageNet as used in fast.ai. If other transformations were applied during training, they may need to be added here as well.
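For intuition about what ToTensor and Normalize do to the pixel values, the following NumPy sketch reproduces the arithmetic on an image that is assumed to be already resized to 96x128. The function name to_normalized_chw and the all-white test image are illustrative only; resizing is left out because it needs an image library.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(image_hwc_uint8):
    """Mimic ToTensor + Normalize for an H x W x 3 uint8 image."""
    scaled = image_hwc_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    chw = scaled.transpose(2, 0, 1)                      # ToTensor: HWC -> CHW
    # Normalize: per-channel (value - mean) / std
    return (chw - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]


image = np.full((96, 128, 3), 255, dtype=np.uint8)  # all-white placeholder image
x = to_normalized_chw(image)[None]                  # add the batch dimension

print(x.shape)
print(x[0, :, 0, 0])  # normalized value of a white pixel in each channel
```

If the statistics used here differ from those used during training, the model receives inputs on a different scale and its predictions will no longer match fast.ai's.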

For more details about the PyTorch weights transferring process, please refer to this AWS sample: notebook/02_Inference_in_pytorch.ipynb.

Deployment to TorchServe

In this section, we deploy the PyTorch model to TorchServe. For installation, please refer to the TorchServe GitHub repository.

Overall, there are three main steps to use TorchServe:

  1. Archive the model into *.mar.
  2. Start TorchServe.
  3. Call the API and get the response.

To archive the model, at least three files are needed in our case:

  1. PyTorch model weights fasti_unet_weights.pth.
  2. PyTorch model definition model.py, which is identical to DynamicUnetDIY definition described in the last section.
  3. TorchServe custom handler.

Custom handler

As shown in /deployment/handler.py, the TorchServe handler accepts data and context. In our example, we define another helper Python class with four instance methods to implement: initialize, preprocess, inference, and postprocess.

initialize

Here we work out whether a GPU is available, then identify the serialized model weights file path, and finally instantiate the PyTorch model and put it in evaluation mode.

    def initialize(self, ctx):
        """
        load eager mode state_dict based model
        """
        properties = ctx.system_properties
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )
        model_dir = properties.get("model_dir")

        manifest = ctx.manifest
        logger.debug(manifest)
        serialized_file = manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the serialized model weights file")

        logger.debug(model_pt_path)

        from model import DynamicUnetDIY

        state_dict = torch.load(model_pt_path, map_location=self.device)
        self.model = DynamicUnetDIY()
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        logger.debug("Model file {0} loaded successfully".format(model_pt_path))
        self.initialized = True
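One edge case in the device expression above: if CUDA is available but gpu_id is absent from the system properties, the string becomes "cuda:None". A slightly more defensive variant might look like the sketch below. The helper name select_device_name is hypothetical, and whether gpu_id can actually be missing on a GPU host is an assumption, not something this article establishes.

```python
def select_device_name(gpu_id, cuda_available):
    """Return a device string such as "cuda:0", falling back to "cpu".

    gpu_id: value of properties.get("gpu_id"), possibly None
    cuda_available: result of torch.cuda.is_available()
    """
    if cuda_available and gpu_id is not None:
        return f"cuda:{gpu_id}"
    return "cpu"


print(select_device_name(0, True))
print(select_device_name(None, True))
print(select_device_name(0, False))
```

The returned string can be passed to torch.device in the same place as the original expression.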

preprocess

As described in the previous section, we redefine the image transform steps and apply them to the inference data.

    def preprocess(self, data):
        """
        Scales and normalizes a PIL image for a U-Net model
        """
        image = data[0].get("data")
        if image is None:
            image = data[0].get("body")

        image_transform = transforms.Compose(
            [
                transforms.Resize((96, 128)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        image = Image.open(io.BytesIO(image)).convert("RGB")
        image = image_transform(image).unsqueeze_(0)
        return image

inference

Now move the image tensor to the GPU if available, and pass it through the model.

    def inference(self, img):
        """
        Predict the segmentation mask of an image using the trained model.
        """
        self.model.eval()
        inputs = img.to(self.device)
        # Gradients are not needed at inference time.
        with torch.no_grad():
            outputs = self.model(inputs)
        logger.debug(outputs.shape)
        return outputs

postprocess

Here the raw inference output is moved off the GPU if one was used, reduced to a class index per pixel, and encoded with Base64 to be returned to the API caller.

    def postprocess(self, inference_output):

        # .cpu() is a no-op when the tensor is already on the CPU.
        inference_output = inference_output[0].argmax(dim=0).cpu()

        return [
            {
                "base64_prediction": base64.b64encode(
                    inference_output.numpy().astype(np.uint8)
                ).decode("utf-8")
            }
        ]

Now we’re ready to set up and launch TorchServe.

TorchServe in action

Step 1: Archive the PyTorch model.

>>> torch-model-archiver --model-name fastunet --version 1.0 --model-file deployment/model.py --serialized-file model_store/fasti_unet_weights.pth --export-path model_store --handler deployment/handler.py -f

Step 2: Serve the model.

>>> torchserve --start --ncs --model-store model_store --models fastunet.mar

Step 3: Call API and get the response. (Here we use httpie.) For a complete response, see sample/sample_output.txt.

>>> time http POST http://127.0.0.1:8080/predictions/fastunet/ @sample/street_view_of_a_small_neighborhood.png

HTTP/1.1 200
Cache-Control: no-cache; no-store, must-revalidate, private
Expires: Thu, 01 Jan 1970 00:00:00 UTC
Pragma: no-cache
connection: keep-alive
content-length: 131101
x-request-id: 96c25cb1-99c2-459e-9165-aa5ef9e3a439

{
  "base64_prediction": "GhoaGhoaGhoaGhoaGhoaGhoaGh...ERERERERERERERERERERER"
}

real    0m0.979s
user    0m0.280s
sys     0m0.039s

The first call has longer latency because the model weights are loaded in initialize, but subsequent calls are faster. For more details about TorchServe setup and usage, please refer to notebook/03_TorchServe.ipynb.
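The same request can also be issued from Python. The following is a minimal standard-library sketch, assuming TorchServe is listening on 127.0.0.1:8080 and the model is registered as fastunet, as in the commands above. The helper names build_prediction_request and predict are hypothetical, and the example only constructs the request, because sending it requires a running server.

```python
import json
import urllib.request


def build_prediction_request(image_bytes, model_name="fastunet",
                             base_url="http://127.0.0.1:8080"):
    """Build a POST request carrying raw image bytes as the body."""
    url = f"{base_url}/predictions/{model_name}"
    return urllib.request.Request(url, data=image_bytes, method="POST")


def predict(image_bytes, timeout=30):
    """Send the request and parse the JSON reply (needs a running TorchServe)."""
    request = build_prediction_request(image_bytes)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read())


# Placeholder bytes; in practice, read the PNG file in binary mode.
request = build_prediction_request(b"placeholder image bytes")
print(request.get_method(), request.full_url)
```

With the server running, predict(open(path, "rb").read()) would be expected to return a dictionary containing the base64_prediction field shown above.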

Deployment to Amazon SageMaker inference endpoint

In this section, we deploy the fast.ai-trained Scene Segmentation PyTorch model with TorchServe in an Amazon SageMaker endpoint using a customized Docker image, running on an ml.g4dn.xlarge instance. Refer to Amazon Elastic Compute Cloud (Amazon EC2) G4 Instances for more details.

Getting started with Amazon SageMaker endpoint

There are four steps to set up an Amazon SageMaker endpoint with TorchServe:

  1. Build a customized Docker image and push it to Amazon Elastic Container Registry (Amazon ECR). The Dockerfile is provided in the root of this code repository and sets up the fast.ai and TorchServe dependencies.
  2. Compress *.mar into *.tar.gz and upload it to Amazon Simple Storage Service (Amazon S3).
  3. Create a SageMaker model using the Docker image from step 1 and the compressed model weights from step 2.
  4. Create the SageMaker endpoint using the model from step 3.

The details of these steps are described in notebook/04_SageMaker.ipynb. Once the endpoint is ready, we can invoke it with an image in real time.
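As an illustration of step 2, the compression can be done with Python's tarfile module. This is a sketch under assumptions rather than the notebook's code: compress_model_archive is a hypothetical helper, the .mar file is a placeholder created in a temporary directory, and the S3 upload is shown only as a comment with a made-up bucket and key.

```python
import tarfile
import tempfile
from pathlib import Path


def compress_model_archive(mar_path, output_path):
    """Pack a single .mar file at the root of a gzip-compressed tarball."""
    mar_path = Path(mar_path)
    with tarfile.open(output_path, "w:gz") as tar:
        # arcname drops the directory part so the .mar sits at the archive root.
        tar.add(mar_path, arcname=mar_path.name)
    return Path(output_path)


with tempfile.TemporaryDirectory() as tmp:
    mar = Path(tmp) / "fastunet.mar"
    mar.write_bytes(b"placeholder for the real model archive")

    tarball = compress_model_archive(mar, Path(tmp) / "model.tar.gz")
    with tarfile.open(tarball, "r:gz") as tar:
        member_names = tar.getnames()

    # Upload (hypothetical bucket and key), assuming boto3 and AWS credentials:
    # boto3.client("s3").upload_file(str(tarball), "my-bucket", "fastunet/model.tar.gz")

print(member_names)
```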

Real-time inference with Python SDK

Read a sample image.

file_name = "street_view_of_a_small_neighborhood.png"

with open(file_name, 'rb') as f:
    payload = f.read()

Invoke the SageMaker endpoint with the image and obtain the response from the API.

import json
import boto3

# endpoint_name is the name given to the endpoint when it was created.
client = boto3.client("runtime.sagemaker")
response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/x-image", Body=payload
)
response = json.loads(response["Body"].read())

Decode the response and visualize the predicted Scene Segmentation mask.

import base64
import numpy as np
import matplotlib.pyplot as plt

pred_decoded_byte = base64.decodebytes(
    bytes(response["base64_prediction"], encoding="utf-8")
)
pred_decoded = np.reshape(
    np.frombuffer(pred_decoded_byte, dtype=np.uint8), (96, 128)
)
plt.imshow(pred_decoded)
plt.axis("off")
plt.show()

Visualization of the predicted Scene Segmentation.
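The encoding used in postprocess and the decoding above can be checked offline with a synthetic mask. The sketch below assumes the same 96x128 uint8 layout; the mask contents are made up, and tobytes() is used to make the byte conversion explicit.

```python
import base64

import numpy as np

HEIGHT, WIDTH = 96, 128

# Synthetic class-index mask with values in [0, 32), standing in for a prediction.
mask = (np.arange(HEIGHT * WIDTH) % 32).astype(np.uint8).reshape(HEIGHT, WIDTH)

# Server side: raw bytes -> Base64 text that can be embedded in JSON.
encoded = base64.b64encode(mask.tobytes()).decode("utf-8")

# Client side: Base64 text -> raw bytes -> array with the original shape.
decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.uint8).reshape(HEIGHT, WIDTH)

print(np.array_equal(mask, decoded))
```

The dtype and shape are not carried in the payload, so the client has to use the same values as the server, which is why (96, 128) and np.uint8 appear on both sides.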

What’s next?

With an inference endpoint up and running, one could leverage its full power by exploring other features that are important for a machine learning product, including AutoScaling, model monitoring with Human-in-the-loop (HITL) using Amazon Augmented AI (A2I), and incremental modeling iteration.

Clean up

Make sure that you delete the following resources to prevent any additional charges:

  1. Amazon SageMaker endpoint.
  2. Amazon SageMaker endpoint configuration.
  3. Amazon SageMaker model.
  4. Amazon ECR repository.
  5. Amazon S3 buckets.
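The first three items can be removed programmatically. The sketch below takes the SageMaker client as a parameter so that it can be exercised without AWS access. delete_endpoint_resources and RecordingClient are hypothetical names, and the resource names are placeholders. In real use, one would pass boto3.client("sagemaker"), whose delete_endpoint, delete_endpoint_config, and delete_model methods take the keyword arguments shown.

```python
def delete_endpoint_resources(sm_client, endpoint_name, endpoint_config_name, model_name):
    """Delete the endpoint, then its configuration, then the model."""
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    sm_client.delete_model(ModelName=model_name)


class RecordingClient:
    """Fake client that records calls instead of contacting AWS."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only reached for attributes that do not exist, i.e. the API methods.
        def record(**kwargs):
            self.calls.append((name, kwargs))
        return record


fake_client = RecordingClient()
delete_endpoint_resources(
    fake_client,
    endpoint_name="fastunet-endpoint",          # placeholder name
    endpoint_config_name="fastunet-config",     # placeholder name
    model_name="fastunet-model",                # placeholder name
)
print([name for name, _ in fake_client.calls])
```

The Amazon ECR repository and the Amazon S3 objects are not covered by this sketch and need to be removed separately.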

Conclusion

This article presented an end-to-end demonstration of deploying fast.ai-trained PyTorch models in TorchServe eager mode and hosting them in an Amazon SageMaker endpoint. You can use this repository as a template to deploy your own fast.ai models. This approach removes the effort of building and maintaining a customized inference server yourself, which helps speed up the path from training a cutting-edge deep learning model to running it online at scale in the real world.

If you have questions, please create an issue or submit a pull request on the GitHub repository.

Baichuan Sun

Dr. Baichuan Sun, currently serving as a Sr. AI/ML Solution Architect at AWS, focuses on generative AI and applies his knowledge in data science and machine learning to provide practical, cloud-based business solutions. With experience in management consulting and AI solution architecture, he addresses a range of complex challenges, including robotics computer vision, time series forecasting, and predictive maintenance, among others. His work is grounded in a solid background of project management, software R&D, and academic pursuits. Outside of work, Dr. Sun enjoys the balance of traveling and spending time with family and friends, reflecting a commitment to both his professional growth and personal well-being.

Calvin Wang

Calvin is a Data Scientist at AWS AI/ML. He holds a B.S. in Computer Science from UC Santa Barbara and loves using machine learning to build cool stuff.

Eden Duthie

AWS Professional Service Machine Learning lead for the APJC region.

Kavitha Rajendran

Kavitha is a Data Scientist at AWS AI/ML. She holds a MS in Computer Science from The University of Texas at Dallas.