Hacking TorchServe and Triton Inference Server for MONAI and Medical Imaging
Authors:
- Dr. Vikash Gupta, Center for Augmented Intelligence in Imaging, Mayo Clinic Florida
- Jiahui Guan, Senior Solutions Architect, NVIDIA
GitHub URL: https://github.com/vikashg/monai-inference-demo
MONAI has become the de facto standard in medical image AI, offering tools for model training and clinical integration. It provides out-of-the-box tools for the preprocessing and postprocessing of radiological images. MONAI features a broad selection of predefined neural network architectures and training routines based on PyTorch.
The MONAI Model Zoo offers access to pre-trained models based on peer-reviewed research, allowing developers to either fine-tune these models or use them to perform inference. In the 1.3 release, the MONAI Model Zoo introduced a Pythonic API that simplifies downloading and fine-tuning models.
On the opposite end of the spectrum is MONAI Deploy Express, which provides a comprehensive deployment solution. It uses a MONAI Application Package (MAP) to create application bundles that can be deployed. These MAPs can take a DICOM input and generate an appropriate DICOM image as output. MONAI Deploy Express is an open-source solution that can be integrated into hospital ecosystems, serving as an inference server and a DICOM router.
MONAI covers two opposite ends of the workflow: MONAI Core handles data preprocessing and model training, and MONAI Deploy Express provides a comprehensive suite of tools for a complete clinical deployment. Research labs, however, often need a middle ground between operating a full-fledged MONAI Deploy Express instance and writing new inference code for every task. A local inference server fills that gap: models are registered once and can then be invoked the same way every time. The advantages of a dedicated model inference server include:
- A common endpoint that can be invoked for inference tasks
- A common directory to maintain models and apps
In this post, we will explore the following:
- How to write a TorchServe handler for a MONAI-based application
- How to write a test function for that handler
- How to deploy a MONAI model using TorchServe
- How to deploy a MONAI model using Triton Inference Server
- How to run inference through a REST API
- A comparison between the two model servers
Prerequisites and Assumptions
We’ll be assuming familiarity with a few key concepts:
- MONAI Transforms: Knowledge of how to use MONAI for transforming medical imaging data
- REST API: Understanding the basics of REST API design for web services interaction
- Deep Learning Concepts: A grasp of fundamental deep learning principles
All of the necessary code, sample data, and models for this guide will be made available through a GitHub Repository.
Creating the ModelHandler
First, we'll break down the code for the ModelHandler into understandable chunks. The handler is built around four methods: preprocess, postprocess, inference, and handle, which calls the other three in turn.
1. Initialize the ModelHandler, which inherits from TorchServe's BaseHandler
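import torch
from ts.torch_handler.base_handler import BaseHandler


class ModelHandler(BaseHandler):
    def __init__(self):
        super().__init__()  # let BaseHandler set up its default attributes
        self._context = None
        self.initialized = False
        self.explain = False
        self.target = 0
        # Use the GPU when one is visible, otherwise fall back to the CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")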
2. Write the data preprocessing function
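import os
from monai.transforms import (Compose, EnsureChannelFirst, LoadImage,
                              Resize, ScaleIntensityRange)

# Method of ModelHandler; input_dir is the module-level data directory defined in handler.py
def preprocess(self, data_fn):
    transforms = Compose([LoadImage(image_only=True),
                          EnsureChannelFirst(),
                          Resize(spatial_size=(256, 256, 24)),
                          ScaleIntensityRange(a_min=20, a_max=1200, b_min=0, b_max=1, clip=True)])
    # TorchServe passes a list of requests; form fields arrive as bytes
    fn = data_fn[0]['filename'].decode()
    img_fullname = os.path.join(input_dir, fn)
    data = []
    batch_size = 1  # one file name per request
    for i in range(batch_size):
        tmp = {}
        print(img_fullname)
        # Add a leading batch axis -> (1, 1, 256, 256, 24); MONAI 1.3 no longer ships AddChannel()
        tmp["data"] = transforms(img_fullname).unsqueeze(0)
        data.append(tmp)
    return data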
The preprocessing function receives the request data from TorchServe, extracts the file name (data_fn[0]['filename']), and loads that file from the data directory. It then runs the MONAI transform chain: the NIfTI image is loaded, resized to a spatial size of 256 x 256 x 24, and its intensities are clipped to the range 20 to 1200 and rescaled to [0, 1]. Finally, a leading batch axis is added, giving a tensor of shape (1, 1, 256, 256, 24), and the result is wrapped in a list. Because each request carries a single file name, batch_size is hardcoded to 1.
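TorchServe hands preprocess a list with one entry per request; for a form-encoded request such as curl -d "filename=test.nii.gz", each entry is typically a dict whose values arrive as bytes or bytearray, which is why the handler calls .decode(). The sketch below isolates that step in plain Python. The helper name resolve_input_path and the path-containment check are our own additions, not part of the original handler: because os.path.join follows "../" segments and absolute paths, a client could otherwise point the server at files outside input_dir.

```python
import os


def resolve_input_path(requests, input_dir):
    """Hypothetical helper: turn TorchServe request data into a safe file path.

    Assumes `requests` is the list passed to preprocess(), with the form field
    "filename" stored in the first entry as bytes/bytearray (str also accepted).
    """
    raw = requests[0]["filename"]
    fn = raw.decode() if isinstance(raw, (bytes, bytearray)) else str(raw)

    base = os.path.realpath(input_dir)
    full = os.path.realpath(os.path.join(base, fn))
    # Reject names that escape the data directory, e.g. "../x.nii.gz" or "/etc/passwd"
    if os.path.commonpath([base, full]) != base:
        raise ValueError(f"{fn!r} resolves outside {input_dir!r}")
    return full


print(resolve_input_path([{"filename": bytearray(b"test.nii.gz")}], "/data"))
```

Note that this stricter variant also rejects absolute paths outside input_dir, which the original handler would accept.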
3. Post-processing function
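from monai.transforms import Activations, AsDiscrete, Compose, SaveImage
# Method of ModelHandler; output_dir is the module-level output directory defined in handler.py
def postprocess(self, inference_output):
    # Sigmoid turns logits into probabilities; the 0.5 threshold yields a binary mask
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    postprocess_output = post_trans(inference_output)
    # Save the first (and only) item of the batch as a NIfTI file with the "seg" postfix
    SaveImage(output_dir, output_postfix='seg', output_ext='.nii.gz')(postprocess_output[0])
    # TorchServe expects one response entry per request
    return [1]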
Like the preprocessing function, the post-processing function applies MONAI transforms, here a sigmoid activation followed by a 0.5 threshold, to the model's output. It then saves the resulting mask to the predefined output directory and returns [1] to signal success (TorchServe expects the handler to return a list with one entry per request).
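The two post-processing transforms can be read as a single rule. Activations(sigmoid=True) maps each logit x to 1 / (1 + e^(-x)), and AsDiscrete(threshold=0.5) marks values at or above 0.5 as foreground. Because the sigmoid is monotonic and equals exactly 0.5 at x = 0, the pair amounts to marking voxels with a non-negative logit (up to floating-point rounding for logits extremely close to zero). The plain-Python sketch below checks that equivalence on a flat list rather than a tensor, purely for illustration; it is not a replacement for the MONAI transforms.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function (avoids overflow for large |x|)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def binarize_via_sigmoid(logits, threshold=0.5):
    # Mirrors Activations(sigmoid=True) followed by AsDiscrete(threshold=0.5)
    return [1.0 if sigmoid(v) >= threshold else 0.0 for v in logits]


def binarize_logits(logits):
    # Shortcut: sigmoid(x) >= 0.5 exactly when x >= 0
    return [1.0 if v >= 0.0 else 0.0 for v in logits]


logits = [-3.2, -0.4, 0.0, 0.7, 5.1]
print(binarize_via_sigmoid(logits))  # [0.0, 0.0, 1.0, 1.0, 1.0]
print(binarize_logits(logits))       # [0.0, 0.0, 1.0, 1.0, 1.0]
```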
4. Inference Function
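import torch

def inference(self, data, *args, **kwargs):
    # data must be a batched tensor such as preprocess(...)[0]["data"]
    with torch.no_grad():
        marshalled_data = data.to(self.device)
        # Pass any extra positional and keyword arguments through to the model
        results = self.model(marshalled_data, *args, **kwargs)
    return results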
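The handle method itself is not listed here. In TorchServe, BaseHandler.handle calls preprocess, inference, and postprocess in that order, so whatever preprocess returns is exactly what inference receives. With the preprocess above, that is a list of {"data": tensor} dicts rather than a tensor, so a custom handle (or inference) has to unpack it before calling .to(). The sketch below shows that wiring with plain-Python stand-ins; the class name SketchHandler and the toy "model" are hypothetical and exist only so the flow runs without TorchServe or PyTorch.

```python
class SketchHandler:
    """Hypothetical stand-in that mimics the handler's control flow."""

    def __init__(self, model):
        self.model = model

    def preprocess(self, requests):
        # Same output structure as the real preprocess: a list of {"data": ...}
        return [{"data": req["filename"].decode()} for req in requests]

    def inference(self, batch):
        # Unpack each dict before calling the model
        return [self.model(item["data"]) for item in batch]

    def postprocess(self, outputs):
        # TorchServe expects one response entry per request
        return [1 for _ in outputs]

    def handle(self, data, context=None):
        # preprocess -> inference -> postprocess, the order BaseHandler.handle uses
        return self.postprocess(self.inference(self.preprocess(data)))


handler = SketchHandler(model=lambda name: f"segmented:{name}")
print(handler.handle([{"filename": b"test.nii.gz"}]))  # [1]
```

In the real handler, the equivalent unpacking would be passing batch[0]["data"] to inference.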
Create a Test Function
Before creating a model server, we must ensure that the model performs inference as expected. We’ll utilize the previously created ModelHandler and create a test function to validate the workflow.
from handler import ModelHandler
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
from pathlib import Path
import os

MODEL_PT_FILE = 'traced_segres_model.pt'
CURR_FILE_PATH = Path(__file__).parent.absolute()
EXAMPLE_ROOT_DIR = CURR_FILE_PATH
TEST_DATA = os.path.join(CURR_FILE_PATH, 'test.nii.gz')


def test_segresnet(batch_size=1):
    handler = ModelHandler()
    print(EXAMPLE_ROOT_DIR.as_posix())
    # Minimal stand-in for the context TorchServe builds when it loads a .mar
    ctx = MockContext(model_pt_file=MODEL_PT_FILE,
                      model_dir=EXAMPLE_ROOT_DIR.as_posix(),
                      model_file=None)
    handler.initialize(ctx)
    handler.context = ctx
    # Wrap the path the way TorchServe delivers a form field: a list of dicts with bytes values
    handler.handle([{"filename": TEST_DATA.encode()}], ctx)


if __name__ == '__main__':
    test_segresnet()
Create a TorchServe endpoint
To set up a TorchServe endpoint with your handler, follow these steps:
1. Create a Torch model archive (.mar) from the TorchScript model (traced_segres_model.pt) and the model handler you wrote earlier. Specify the model name, serialized file, handler path, and version.
torch-model-archiver --model-name segresnet --serialized-file ./traced_segres_model.pt --handler ./handler.py -v 1.0
2. Move the archive to the model_store folder.
mkdir -p model_store
mv segresnet.mar model_store/
3. Start TorchServe. Specify the model store directory, the model to deploy, and the TorchServe configuration file.
torchserve --start --model-store model_store/ --models segresnet=segresnet.mar --ts-config ./config.properties
4. Check that TorchServe is running by sending a request to its health endpoint.
curl http://localhost:8080/ping
Expected Response:
{
  "status": "Healthy"
}
5. List all the models deployed on the TorchServe model server through the management API, which listens on port 8081 by default:
curl http://localhost:8081/models
Expected Output:
{
  "models": [
    {
      "modelName": "breast",
      "modelUrl": "breast.mar"
    },
    {
      "modelName": "segresnet",
      "modelUrl": "segresnet.mar"
    }
  ]
}
6. Make a prediction request by passing the filename as data
curl -d "filename=test.nii.gz" http://127.0.0.1:8080/predictions/segresnet
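If the prediction succeeds, TorchServe returns 1, and the segmentation is written to the handler's output directory (output_dir). Note that the file name is resolved on the server relative to the handler's input_dir, so test.nii.gz must already be in that directory.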
Because each model is packaged as its own .mar file in model_store, a single TorchServe instance can serve several models side by side, as the breast and segresnet entries in the model listing show.
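If you would rather call these endpoints from Python than from curl, the standard library is enough. The sketch below builds the same three requests used above (health check, model listing, prediction). The function names are hypothetical, and it assumes the default inference port 8080 and management port 8081 shown in the curl commands. The request builders and the JSON parsing work without a running server; only send() performs a network call.

```python
import json
import urllib.parse
import urllib.request

# Assumed TorchServe defaults, matching the curl commands above
INFERENCE_URL = "http://127.0.0.1:8080"
MANAGEMENT_URL = "http://127.0.0.1:8081"


def ping_request():
    # GET /ping on the inference API
    return urllib.request.Request(f"{INFERENCE_URL}/ping")


def list_models_request():
    # GET /models on the management API
    return urllib.request.Request(f"{MANAGEMENT_URL}/models")


def predict_request(model_name, filename):
    # Same form-encoded body that `curl -d "filename=..."` sends
    body = urllib.parse.urlencode({"filename": filename}).encode()
    return urllib.request.Request(
        f"{INFERENCE_URL}/predictions/{model_name}",
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


def model_names(listing_json):
    # Pull the model names out of the {"models": [...]} listing
    return [m["modelName"] for m in json.loads(listing_json)["models"]]


def send(request, timeout=30):
    # Perform the HTTP call and return the response body as text
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return resp.read().decode()
```

With the server running, send(ping_request()) should return the health JSON, model_names(send(list_models_request())) the deployed model names, and send(predict_request("segresnet", "test.nii.gz")) the handler's 1.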
Triton Inference Server-based model server for MONAI models
Triton Inference Server can deploy models from many frameworks, including TensorRT, TensorFlow, PyTorch, ONNX, and OpenVINO. In this post, we demonstrate how to deploy MONAI models on Triton Inference Server. Because MONAI is built on PyTorch, we deploy a PyTorch (TorchScript) model.
Steps:
- Convert the model to TorchScript using the following code.
- Write the config.pbtxt file required by Triton Inference Server.
- As with TorchServe, write the pre- and post-processing code; for Triton, it lives in an image client.
The following code converts the PyTorch model into a traced TorchScript model; it can also be downloaded from the GitHub repository.
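import torch
from model_del import ModelDefinition  # model definition module from the GitHub repository

model_def = ModelDefinition(model_name='SegResNet')
model = model_def.get_model()
model_fn = './model/SegResNet/best_metric_model.pth'  # trained weights
model.load_state_dict(torch.load(model_fn, map_location='cpu'))
model.eval()  # switch to inference mode (e.g. disables dropout) before tracing
# Example input with shape (batch, channel, H, W, D), matching the preprocessing output
x = torch.zeros(1, 1, 256, 256, 24)
traced_model = torch.jit.trace(model, x)
traced_model.save('./model/SegResNet/traced_segres_model.pt')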
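The traced model is saved as ./model/SegResNet/traced_segres_model.pt. This is the same file used for the TorchServe archive, and it becomes model.pt in the Triton model repository described below.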
Now, we need to write the config.pbtxt file.
name: "lv_segmentation"
platform: "pytorch_libtorch"
max_batch_size: 2
input [
  {
    name: "input__0"
    data_type: TYPE_FP32
    dims: [1, 256, 256, 24]
  }
]
output [
  {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [1, 256, 256, 24]
  }
]
Explanation of the config.pbtxt file
The name field is the model's name, and it should match the name of the model's directory in the model repository. The platform field identifies the framework the model was saved with; our model is a TorchScript model exported from PyTorch, so the value is pytorch_libtorch. The max_batch_size field is the largest number of images Triton will accept in a single request. The input and output fields each have three sub-fields: name, data_type, and dims. The name refers to the input or output node of the model. Triton's PyTorch backend conventionally identifies TorchScript tensors by position using the pattern <name>__<index>, so input__0 and output__0 are a safe default; if the server rejects them, check the model's actual input and output names. The data_type field gives the element type, here TYPE_FP32 (32-bit float). The dims field gives the shape of the input image and of the expected output. Because max_batch_size is greater than zero, dims excludes the batch axis, so [1, 256, 256, 24] describes one single-channel 256 x 256 x 24 volume. For a segmentation model such as this one, the input and output dims are the same.
More information about the config.pbtxt file is available in the official Triton GitHub repository: https://github.com/triton-inference-server/server/blob/main/docs/getting_started/quickstart.md
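Since the dims entries are just the tracing shape without its batch axis, a small script can generate config.pbtxt from that shape so the two cannot drift apart. The sketch below is a hypothetical helper (make_config is our own name, not a Triton API) that writes only the fields used above and assumes, as for this segmentation model, that the output has the same shape as the input.

```python
def make_config(name, example_shape, max_batch_size=2,
                input_name="input__0", output_name="output__0"):
    """Hypothetical helper: render a minimal config.pbtxt for a TorchScript model.

    example_shape is the full tracing shape, batch axis first. When
    max_batch_size > 0, Triton adds the batch axis itself, so it is dropped.
    """
    dims = list(example_shape[1:]) if max_batch_size > 0 else list(example_shape)
    dims_txt = ", ".join(str(d) for d in dims)

    def tensor_block(tensor_name):
        # One FP32 tensor entry; output shape assumed equal to input shape
        return (
            "  {\n"
            f'    name: "{tensor_name}"\n'
            "    data_type: TYPE_FP32\n"
            f"    dims: [{dims_txt}]\n"
            "  }\n"
        )

    return (
        f'name: "{name}"\n'
        'platform: "pytorch_libtorch"\n'
        f"max_batch_size: {max_batch_size}\n"
        f"input [\n{tensor_block(input_name)}]\n"
        f"output [\n{tensor_block(output_name)}]\n"
    )


# Same shape as the tensor passed to torch.jit.trace
print(make_config("lv_segmentation", (1, 1, 256, 256, 24)))
```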
Deploying on Triton Inference Server
To deploy the models, arrange them in a model repository with the following directory structure:
model_repository/models/
    breast_density/
        config.pbtxt
        1/
            model.pt
    lv_segmentation/
        config.pbtxt
        1/
            model.pt
This repository contains two models: breast_density and lv_segmentation. For each model, the model file model.pt sits inside a folder named 1, which is the model version. Once the directory is set up, start the Triton server with the following command:
docker run --gpus=1 -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}/model_repository/models:/models nvcr.io/nvidia/tritonserver:23.12-py3 tritonserver --model-repository=/models
This command mounts the local model_repository/models directory into the container as /models and points tritonserver at it. Ports 8000, 8001, and 8002 expose the HTTP, gRPC, and metrics endpoints, respectively. On the first run, Docker downloads the nvcr.io/nvidia/tritonserver:23.12-py3 image. Choose an image tag that is compatible with your GPU driver and with the PyTorch version used to trace the model; the available tags are listed at https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
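The repository layout can also be created with a short script. The sketch below assumes the traced file from the conversion step and a config.pbtxt text you already have; it copies the TorchScript file to <repo>/<model>/1/model.pt, the file name Triton's PyTorch backend looks for by default, and writes config.pbtxt next to the version folder. The function name and its arguments are hypothetical.

```python
import shutil
from pathlib import Path


def add_model_to_repository(repo_root, model_name, traced_file, config_text, version=1):
    """Hypothetical helper: lay out one model in a Triton model repository.

    Produces:
        <repo_root>/<model_name>/config.pbtxt
        <repo_root>/<model_name>/<version>/model.pt
    """
    model_dir = Path(repo_root) / model_name
    version_dir = model_dir / str(version)
    version_dir.mkdir(parents=True, exist_ok=True)
    # Triton's PyTorch backend loads "model.pt" by default
    shutil.copyfile(traced_file, version_dir / "model.pt")
    (model_dir / "config.pbtxt").write_text(config_text)
    return model_dir
```

For example, add_model_to_repository("model_repository/models", "lv_segmentation", "./model/SegResNet/traced_segres_model.pt", config_text) should reproduce the lv_segmentation entry of the layout described above.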
Writing the image-client
Now we can write image-client.py. The image client contains two main functions: preprocess and post_transform. The preprocess function is shown below. Although it uses the same MONAI transform chain as the TorchServe handler, it converts the result to a NumPy array and adds a leading batch axis, because the Triton client library sends request data as NumPy arrays (set_data_from_numpy).
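import numpy as np
from monai.transforms import (Compose, EnsureChannelFirst, LoadImage,
                              Resize, ScaleIntensityRange)


def preprocess(img_path="MR.nii.gz"):
    transforms = Compose([LoadImage(image_only=True),
                          EnsureChannelFirst(),
                          Resize(spatial_size=(256, 256, 24)),
                          ScaleIntensityRange(a_min=20, a_max=1200, b_min=0, b_max=1, clip=True)])
    img_tensor = transforms(img_path)
    # MetaTensor -> NumPy with a batch axis: (1, 1, 256, 256, 24), float32 to match TYPE_FP32
    results_np = np.expand_dims(img_tensor.numpy(), axis=0).astype(np.float32)
    return results_np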
We also need to post-process the output and save it, either as a NIfTI file or as a DICOM-RT file; here we save it as NIfTI. The function uses SimpleITK to write the file and copies the spatial metadata (spacing, origin, and direction) from a reference image, which is usually the input image.
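import os
import numpy as np
import SimpleITK as sitk
from monai.transforms import Activations, AsDiscrete, Compose

def post_transform(inference_output, out_dir='./', ref_image=None):
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    # Sigmoid + threshold, then drop the batch and channel axes: (256, 256, 24)
    postprocess_output = np.squeeze(np.asarray(post_trans(inference_output))).astype(np.uint8)
    # MONAI's (x, y, z) array order -> SimpleITK's (z, y, x)
    image_itk = sitk.GetImageFromArray(np.transpose(postprocess_output, [2, 1, 0]))
    # Copy the geometry of the reference image onto the mask
    image_itk.SetSpacing(ref_image.GetSpacing())
    image_itk.SetOrigin(ref_image.GetOrigin())
    image_itk.SetDirection(ref_image.GetDirection())
    sitk.WriteImage(image_itk, os.path.join(out_dir, "segmentation.nii.gz"))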
Finally, the driver function for the image-client is written as:
import SimpleITK as sitk
import tritonclient.http as httpclient


def main():
    img_path = "/data/MR.nii.gz"
    transformed_image = preprocess(img_path=img_path)

    # Send the (1, 1, 256, 256, 24) float32 array to Triton over HTTP
    client = httpclient.InferenceServerClient(url="localhost:8000")
    inputs = httpclient.InferInput("input__0", transformed_image.shape, datatype="FP32")
    inputs.set_data_from_numpy(transformed_image, binary_data=True)
    outputs = httpclient.InferRequestedOutput("output__0", binary_data=True, class_count=0)
    results = client.infer(model_name="lv_segmentation", inputs=[inputs], outputs=[outputs])
    inference_output = results.as_numpy("output__0")

    # The reference image supplies spacing, origin and direction for the saved mask
    ref_img_fn = './tmp/MR/MR_preprocessed.nii.gz'
    reader = sitk.ImageFileReader()
    reader.SetFileName(ref_img_fn)
    ref_image = reader.Execute()
    post_transform(inference_output, out_dir='./', ref_image=ref_image)


if __name__ == "__main__":
    main()
As before, the complete program is available on the GitHub page. The image-client is called from the command line as
python image-client.py
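Before running the client, it can help to confirm that both the server and the model are ready. Triton implements the KServe v2 HTTP endpoints /v2/health/ready and /v2/models/<name>/ready, which answer with HTTP 200 when ready; tritonclient's InferenceServerClient also offers is_server_ready() and is_model_ready() for the same purpose. The sketch below does the check with the standard library only; the helper names are our own.

```python
import urllib.error
import urllib.request


def ready_urls(base_url, model_name):
    # Triton (KServe v2) readiness endpoints for the server and one model
    return [
        f"{base_url}/v2/health/ready",
        f"{base_url}/v2/models/{model_name}/ready",
    ]


def is_ready(url, timeout=5):
    # 200 means ready; an error status, refused connection or timeout means not ready
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False
```

For example, all(is_ready(u) for u in ready_urls("http://localhost:8000", "lv_segmentation")) should be True once Triton has loaded lv_segmentation.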
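The image client plays the same role as the TorchServe model handler, with two main differences. First, with Triton the preprocessing and post-processing run in the client, whereas TorchServe runs them on the server inside the handler. Second, the Triton client sends the model input as a NumPy array, whereas the TorchServe handler receives the request data and itself builds the list of dictionaries holding the MONAI tensor.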
Conclusion
MONAI has established itself as a go-to standard for processing medical images. The MONAI Deploy App SDK and MONAI Deploy provide clinical integration tools for model deployment. However, we believe there is a large gap between training a model with MONAI Core and a full clinical deployment with MONAI Deploy, and MONAI itself offers no lightweight serving option for that middle ground. In this post, we took a PyTorch model trained with MONAI and deployed it with two model servers, TorchServe and Triton Inference Server.
We would like to thank Dr. Mutlu Demirer, Dr. Barbaros Selnur Erdal, and Dr. Richard D. White from the Center for Augmented Intelligence in Imaging at Mayo Clinic, Florida, for their guidance and support on this project.