TorchScript Variables and Functions
Several aspects of variables and functions need attention when writing a model that will be converted to TorchScript. The most common ones are listed below. For more detail on these points, and on others that may arise when converting your own model to TorchScript, see the most recent TorchScript documentation as well as the TorchScript documentation specific to your PyTorch version.
External Libraries
Many of Python’s built-in functions are supported in TorchScript, along with Python’s ‘math’ module. However, no other Python modules are supported. Any part of your model that uses other Python modules (e.g. NumPy or OpenCV) must therefore be rewritten using only functions that TorchScript supports.
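As an illustration, the sketch below shows a small helper written with only torch and the math module, where NumPy code might have used np.clip. The function name normalize_and_clip and its parameters are hypothetical and are not part of the original model.

```python
import math

import torch


@torch.jit.script
def normalize_and_clip(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Hypothetical helper, used here for illustration only.
    # math.sqrt is available because TorchScript supports the math module.
    factor = 1.0 / math.sqrt(scale)
    # torch.clamp plays the role that np.clip would play in NumPy code.
    return torch.clamp(x * factor, min=0.0, max=1.0)


result = normalize_and_clip(torch.tensor([-1.0, 1.0, 4.0]), 4.0)
print(result)
```

With scale set to 4.0 the input is multiplied by 0.5 and then clipped to the range [0, 1], so the three values become 0.0, 0.5 and 1.0.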
Class Attribute Initialisation
Class attributes must be declared in the __init__ function. For example, assigning self.label for the first time in forward, outside of __init__, as in:
class Cifar10ResNet(torch.nn.Module):
    """
    This class is a wrapper for our custom ResNet model.
    """
    def __init__(self):
        super(Cifar10ResNet, self).__init__()
        self.model = ResNet(pretrained=True)
        self.model.eval()

    def forward(self, input):
        self.label = 1
        return torch.ones(1, 1, input.shape[2], input.shape[3])
will cause an error when trying to convert this model to TorchScript. The correct approach is to declare the attribute in __init__:
class Cifar10ResNet(torch.nn.Module):
    """
    This class is a wrapper for our custom ResNet model.
    """
    def __init__(self):
        super(Cifar10ResNet, self).__init__()
        self.model = ResNet(pretrained=True)
        self.model.eval()
        self.label = 1

    def forward(self, input):
        return torch.ones(1, 1, input.shape[2], input.shape[3])
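The difference can be checked with a small self-contained sketch. The two module classes and the can_script helper below are hypothetical stand-ins that do not depend on the custom ResNet model. The helper catches a generic Exception because the exact exception class raised by torch.jit.script is not assumed here.

```python
import torch


class LabelInForward(torch.nn.Module):
    # Hypothetical module: self.label is first assigned outside __init__.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.label = 1
        return x


class LabelInInit(torch.nn.Module):
    # Hypothetical module: self.label is declared in __init__.
    def __init__(self):
        super().__init__()
        self.label = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reassigning with a value of the same type is allowed.
        self.label = 2
        return x


def can_script(module: torch.nn.Module) -> bool:
    """Return True if torch.jit.script accepts the module."""
    try:
        torch.jit.script(module)
        return True
    except Exception:
        return False


bad_ok = can_script(LabelInForward())
good_ok = can_script(LabelInInit())
print(bad_ok, good_ok)
```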
Static Variables
Unlike in Python, every variable in TorchScript must have a single static type. For example, this forward function will raise an error when converted to TorchScript because the type of the variable r depends on which branch of the if statement runs:
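def forward(self, input):
    if self.label == 1:
        r = 1  # r is an integer in this branch
    else:
        r = torch.zeros(1, 1, input.shape[2], input.shape[3])  # r is a tensor in this branch
    # r is used after the if statement, so it needs a single type
    return input + r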
This function should be rewritten as:
def forward(self, input):
    if self.label == 1:
        r = torch.ones(1, 1, input.shape[2], input.shape[3])
    else:
        r = torch.zeros(1, 1, input.shape[2], input.shape[3])
    # r is a tensor in both branches
    return input + r
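The same rule can be tried out on standalone functions. In the sketch below, the function names mixed_types and single_type and the try_script helper are hypothetical. The helper catches a generic Exception because the exact exception class is not assumed.

```python
import torch


def mixed_types(x: torch.Tensor, flag: bool) -> torch.Tensor:
    # r is an int in one branch and a tensor in the other.
    if flag:
        r = 1
    else:
        r = torch.zeros_like(x)
    return x + r


def single_type(x: torch.Tensor, flag: bool) -> torch.Tensor:
    # r is a tensor in both branches.
    if flag:
        r = torch.ones_like(x)
    else:
        r = torch.zeros_like(x)
    return x + r


def try_script(fn):
    """Return the scripted function, or None if scripting fails."""
    try:
        return torch.jit.script(fn)
    except Exception:
        return None


scripted_mixed = try_script(mixed_types)
scripted_single = try_script(single_type)
print(scripted_mixed is None, scripted_single is not None)
```

Both functions run without complaint in plain Python. Only the conversion step rejects the first one.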
Similarly, every function must be defined so that its return type is unambiguous and does not change. For example, consider a forward function defined as:
def forward(self, input):
    if self.label == 1:
        return 1
    else:
        return torch.zeros(1, 1, input.shape[2], input.shape[3])
The return type of this function can be either an integer or a tensor, depending on the if statement, so the function will raise an error when converted to TorchScript. It should be redefined, for example as:
def forward(self, input):
    if self.label == 1:
        return torch.ones(1, 1, input.shape[2], input.shape[3])
    else:
        return torch.zeros(1, 1, input.shape[2], input.shape[3])
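One way to make the intended return type explicit is to annotate it. The following is a minimal sketch that assumes a standalone function rather than a module method. The name label_image and its label parameter are hypothetical.

```python
import torch


@torch.jit.script
def label_image(x: torch.Tensor, label: int) -> torch.Tensor:
    # Both branches return a tensor, which matches the annotation.
    if label == 1:
        return torch.ones([1, 1, x.shape[2], x.shape[3]])
    else:
        return torch.zeros([1, 1, x.shape[2], x.shape[3]])


ones_image = label_image(torch.rand(1, 3, 4, 5), 1)
zeros_image = label_image(torch.rand(1, 3, 4, 5), 0)
print(ones_image.shape, zeros_image.shape)
```

Both calls return a tensor of shape (1, 1, 4, 5), filled with ones in the first case and zeros in the second.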
None and Optional Types
Since TorchScript is statically typed, type annotation may be needed to ensure that variable types are correctly inferred when using None. For example, an assignment x = None in your model will cause x to be inferred as NoneType, even when it is intended to be an Optional type. In this case, the annotation x: Optional[int] = None can be used to make clear that x is Optional and can hold either an integer or None.
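A minimal sketch of such an annotation on a local variable is given here. The function name pick and its flag parameter are hypothetical.

```python
from typing import Optional

import torch


@torch.jit.script
def pick(flag: bool) -> int:
    # Without the annotation, x would be inferred as NoneType.
    x: Optional[int] = None
    if flag:
        x = 5
    if x is None:
        return -1
    else:
        # In this branch x has been refined from Optional[int] to int.
        return x


print(pick(True), pick(False))
```

The comparison against None is what allows x to be used as a plain integer in the else branch.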
The following code snippet shows how to annotate an Optional class attribute: the annotation is placed in the class body and the attribute is assigned in __init__(). It also demonstrates that, to refine a class attribute of Optional type outside of __init__(), the attribute must first be assigned to a local variable.
import torch
import torch.nn as nn
from resnet import ResNet
from typing import Optional


class ObjectDetection(torch.nn.Module):
    """
    This model detects objects in an image. If an object is detected, an image
    filled with the object label is returned. If no object is detected, an image
    of zeros is returned.
    """
    # The annotation tells TorchScript that label can be an integer or None
    label: Optional[int]

    def __init__(self):
        super(ObjectDetection, self).__init__()
        self.model = ResNet(pretrained=True)
        self.model.eval()
        label = None
        self.label = label

    def forward(self, input):
        """
        This forward function updates self.label using the object label returned by the model.
        It also returns an image filled with the detected object label. This is an image
        of zeros if no object has been detected.
        """
        objectLabel = 0  # objectLabel is an integer
        output = objectLabel * torch.ones((1, 1, input.shape[2], input.shape[3]))
        # To refine self.label, its value must first be assigned to an Optional[int] variable
        # which can then be updated and reassigned to self.label
        label = self.label
        if objectLabel == 0:
            label = None
        else:
            label = objectLabel
        self.label = label
        return output
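The refinement step itself can be sketched in a smaller module that does not depend on the custom ResNet model. The class LabelHolder and its constructor argument are hypothetical, and the sketch assumes that a None check on the local copy is enough for TorchScript to treat it as an integer inside the if branch.

```python
from typing import Optional

import torch


class LabelHolder(torch.nn.Module):
    # Hypothetical module with an Optional[int] attribute.
    label: Optional[int]

    def __init__(self, label: Optional[int] = None):
        super().__init__()
        self.label = label

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Copy the attribute to a local variable so that it can be refined.
        label = self.label
        if label is not None:
            # Here label is treated as an int.
            out = x * label
        else:
            out = torch.zeros_like(x)
        return out


scripted_none = torch.jit.script(LabelHolder())
scripted_three = torch.jit.script(LabelHolder(3))
print(scripted_none(torch.ones(2)), scripted_three(torch.ones(2)))
```

With no label the module returns zeros, and with a label of 3 it returns the input scaled by 3.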
For more details on Optional type annotation, see the official TorchScript Language Reference pages version 1 and version 2.