Deploy models into production (advanced)

Audience: Machine learning engineers optimizing models for enterprise-scale production environments.


Export your model with torch.export

torch.export is the recommended way to capture PyTorch models for deployment in production environments. It produces a clean intermediate representation with strong soundness guarantees, which makes exported models suitable for inference optimization and cross-platform deployment. A LightningModule is a torch.nn.Module, so you can pass it directly to torch.export.export(), as long as its forward() can be traced.

import torch
from lightning.pytorch import LightningModule
from torch.export import export

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


# create the model and example input
model = SimpleModel()
example_input = torch.randn(1, 64)

# export the model
exported_program = export(model, (example_input,))

# save the exported program for use in a production environment
torch.export.save(exported_program, "model.pt2")

torch.export is available in PyTorch 2.1 and later and is still under active development, so install the latest supported PyTorch release to get the most complete operator and model coverage. Once you have saved the exported program, you can load and run it:

inp = torch.rand(1, 64)
loaded_program = torch.export.load("model.pt2")
output = loaded_program.module()(inp)

torch.export.export() traces a module's forward() method. To export a different method, such as predict_step, wrap it in a small torch.nn.Module whose forward() calls that method:

import lightning as L
import torch
from torch import nn


class LitMCdropoutModel(L.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # enable Monte Carlo Dropout
        self.dropout.train()

        # take average of `self.mc_iteration` iterations
        pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
        pred = torch.vstack(pred).mean(dim=0)
        return pred


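model = LitMCdropoutModel(...)  # pass your base model and the number of MC iterations
example_batch = torch.randn(32, 10)  # example input


# torch.export.export() expects an nn.Module, so route predict_step through forward()
# (the wrapper name is illustrative)
class PredictStepWrapper(torch.nn.Module):
    def __init__(self, lit_model):
        super().__init__()
        self.lit_model = lit_model

    def forward(self, batch):
        # predict_step ignores batch_idx, so pass a constant
        return self.lit_model.predict_step(batch, 0)


# Export the predict_step method
exported_program = torch.export.export(PredictStepWrapper(model), (example_batch,))
torch.export.save(exported_program, "mc_dropout_model.pt2")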