TorchScript Variables and Functions

Several points about variables and functions need attention when writing a model that will be converted to TorchScript. The ones that arise most frequently are listed below. For more details on each of them, and on other issues that may come up when converting your own model, see both the latest TorchScript documentation and the TorchScript documentation for your PyTorch version.

External Libraries

TorchScript supports many of Python’s built-in functions, as well as Python’s ‘math’ module, but no other Python modules. Any part of your model that relies on another module (e.g. NumPy or OpenCV) must therefore be rewritten using only functions that TorchScript supports, such as PyTorch tensor operations.
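As a minimal sketch of such a rewrite, the two helpers below (gaussian_pdf and normalise are hypothetical names, not part of any library) replace what would typically be NumPy calls with torch operations plus the supported ‘math’ module, so that torch.jit.script can compile them:

```python
import math

import torch


# Hypothetical helper: standard normal density.
# A NumPy version would be np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi).
@torch.jit.script
def gaussian_pdf(x: torch.Tensor) -> torch.Tensor:
    # math.sqrt and math.pi are usable because TorchScript supports 'math'
    return torch.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


# Hypothetical helper: zero-mean, unit-variance normalisation.
# A NumPy version would call arr.mean() and arr.var() on x.numpy().
@torch.jit.script
def normalise(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    mean = x.mean()
    var = ((x - mean) ** 2).mean()  # population variance, like np.var's default
    return (x - mean) / torch.sqrt(var + eps)


print(gaussian_pdf(torch.tensor([0.0, 1.0])))
print(normalise(torch.tensor([1.0, 2.0, 3.0, 4.0])))
```

Only tensor methods and math functions appear inside the decorated functions, which is what allows them to compile.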

Class Attribute Initialisation

Class attributes must be declared in the __init__() function. For example, the following model, which creates the attribute self.label in forward() rather than in __init__():

import torch
from resnet import ResNet  # custom ResNet model

class Cifar10ResNet(torch.nn.Module):
    """
    This class is a wrapper for our custom ResNet model.
    """
    def __init__(self):
        super(Cifar10ResNet, self).__init__()
        self.model = ResNet(pretrained=True)
        self.model.eval()

    def forward(self, input):
        self.label = 1  # error: attribute first created outside __init__()
        return torch.ones(1, 1, input.shape[2], input.shape[3])

will cause an error when the model is converted to TorchScript. The attribute should instead be declared in __init__():

import torch
from resnet import ResNet  # custom ResNet model

class Cifar10ResNet(torch.nn.Module):
    """
    This class is a wrapper for our custom ResNet model.
    """
    def __init__(self):
        super(Cifar10ResNet, self).__init__()
        self.model = ResNet(pretrained=True)
        self.model.eval()
        self.label = 1  # declared in __init__(), so TorchScript knows its type (int)

    def forward(self, input):
        return torch.ones(1, 1, input.shape[2], input.shape[3])
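To see the difference in practice, the sketch below scripts two cut-down versions of the wrapper. The class names are hypothetical, and the ResNet submodule is left out so the example is self-contained. Scripting the first class is expected to fail because self.label is first created inside forward(); the second compiles and can be called.

```python
import torch


class LabelSetInForward(torch.nn.Module):
    # No __init__ override: self.label does not exist when forward() is compiled
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        self.label = 1
        return torch.ones(1, 1, input.shape[2], input.shape[3])


class LabelSetInInit(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.label = 1  # declared in __init__, so TorchScript records an int attribute

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.ones(1, 1, input.shape[2], input.shape[3]) * self.label


try:
    torch.jit.script(LabelSetInForward())
except Exception as err:  # the exact exception class depends on the PyTorch version
    print("scripting failed:", type(err).__name__)

scripted = torch.jit.script(LabelSetInInit())
print(scripted(torch.zeros(1, 3, 4, 5)).shape)
```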

Static Variables

In TorchScript, unlike in Python, every variable must have a single static type. For example, the following forward function will cause an error when converted to TorchScript, because r is an int in one branch of the if statement and a Tensor in the other:

def forward(self, input):
    if self.label == 1:
        r = 1  # r is an int here...
    else:
        r = torch.zeros(1, 1, input.shape[2], input.shape[3])  # ...but a Tensor here
    return input + r

This function should be rewritten as:

def forward(self, input):
    if self.label == 1:
        r = torch.ones(1, 1, input.shape[2], input.shape[3])
    else:
        r = torch.zeros(1, 1, input.shape[2], input.shape[3])
    return input + r  # r is a Tensor on both paths
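A minimal sketch of the same rule using two standalone functions (the names mixed_types and single_type are hypothetical, and the boolean flag stands in for self.label == 1): scripting the first is expected to fail because r has two different types, while the second, where r is always a Tensor, compiles.

```python
import torch


def mixed_types(input: torch.Tensor, flag: bool) -> torch.Tensor:
    if flag:
        r = 1  # int in this branch...
    else:
        r = torch.zeros(1, 1, input.shape[2], input.shape[3])  # ...Tensor in this one
    return input + r


def single_type(input: torch.Tensor, flag: bool) -> torch.Tensor:
    if flag:
        r = torch.ones(1, 1, input.shape[2], input.shape[3])
    else:
        r = torch.zeros(1, 1, input.shape[2], input.shape[3])
    return input + r  # r is a Tensor on every path


try:
    torch.jit.script(mixed_types)
except Exception as err:
    print("scripting failed:", type(err).__name__)

scripted = torch.jit.script(single_type)
print(scripted(torch.zeros(2, 3, 4, 5), True).sum())
```

Both functions run in plain (eager) Python; the difference only shows up when TorchScript compiles them.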

Similarly, every function must have a single, consistent return type. For example, consider a forward function defined as:

def forward(self, input):
    if self.label == 1:
        return 1
    else:
        return torch.zeros(1, 1, input.shape[2], input.shape[3])

Depending on the if statement, this function returns either an integer or a tensor, so it will cause an error when converted to TorchScript. It should be redefined, for example as:

def forward(self, input):
    if self.label == 1:
        return torch.ones(1, 1, input.shape[2], input.shape[3])
    else:
        return torch.zeros(1, 1, input.shape[2], input.shape[3])
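As a sketch with hypothetical function names, the first function below returns an int on one path and a Tensor on the other, so scripting it is expected to fail; the second always returns a Tensor, and its explicit -> torch.Tensor annotation states that single return type up front.

```python
import torch


def mixed_return(input: torch.Tensor, flag: bool):
    if flag:
        return 1  # int on this path...
    else:
        return torch.zeros(1, 1, input.shape[2], input.shape[3])  # ...Tensor on this one


def tensor_return(input: torch.Tensor, flag: bool) -> torch.Tensor:
    # Every path returns a Tensor, matching the annotation
    if flag:
        return torch.ones(1, 1, input.shape[2], input.shape[3])
    else:
        return torch.zeros(1, 1, input.shape[2], input.shape[3])


try:
    torch.jit.script(mixed_return)
except Exception as err:
    print("scripting failed:", type(err).__name__)

print(torch.jit.script(tensor_return)(torch.zeros(1, 3, 4, 5), True).shape)
```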

None and Optional Types

Because TorchScript is statically typed, a variable initialised to None may need a type annotation so that its type is inferred correctly. For example, the assignment x = None causes x to be inferred as NoneType, even if x is meant to be Optional and later hold an integer. Annotating it as x: Optional[int] = None tells TorchScript that x can hold either an integer or None.
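The same applies to local variables in scripted functions. In the sketch below (the function names are hypothetical), the unannotated version is expected to fail to script because result is inferred as NoneType and is later assigned an int; the annotated version compiles and returns either an int or None.

```python
from typing import List, Optional

import torch


def first_positive_unannotated(values: List[int]) -> Optional[int]:
    result = None  # inferred as NoneType
    for v in values:
        if result is None and v > 0:
            result = v  # an int cannot be assigned to a NoneType variable
    return result


def first_positive(values: List[int]) -> Optional[int]:
    result: Optional[int] = None  # the annotation makes result Optional[int]
    for v in values:
        if result is None and v > 0:
            result = v
    return result


try:
    torch.jit.script(first_positive_unannotated)
except Exception as err:
    print("scripting failed:", type(err).__name__)

scripted = torch.jit.script(first_positive)
print(scripted([-1, 3, 5]), scripted([-2, 0]))
```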

The following code snippet shows how to annotate an Optional class attribute with a class-level annotation and initialise it in __init__(). It also shows that, outside __init__(), an Optional class attribute must first be assigned to a local variable before its type can be refined.

import torch
import torch.nn as nn
from resnet import ResNet
from typing import Optional

class ObjectDetection(torch.nn.Module):
    """
    This model detects objects in an image. If an object is detected, an image
    filled with the object label is returned. If no object is detected, an image
    of zeros is returned.
    """
    # Class-level annotation: self.label may hold an int or None
    label: Optional[int]

    def __init__(self):
        super(ObjectDetection, self).__init__()
        self.model = ResNet(pretrained=True)
        self.model.eval()
        self.label = None

    def forward(self, input):
        """
        This forward function updates self.label using the object label returned by the model.
        It also returns an image filled with the detected object label. This is an image
        of zeros if no object has been detected.
        """
        objectLabel = 0  # placeholder integer label; in practice it comes from self.model
        output = objectLabel * torch.ones((1, 1, input.shape[2], input.shape[3]))

        # To refine self.label, its value must first be assigned to an Optional[int] local
        # variable, which can then be updated and reassigned to self.label
        label = self.label
        if objectLabel == 0:
            label = None
        else:
            label = objectLabel
        self.label = label

        return output
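The local-variable requirement matters most when the attribute is read. As a sketch (the AddOffset classes and their offset attribute are hypothetical), the first module copies its Optional[int] attribute into a local before comparing it with None, so TorchScript refines the local to int inside the branch; the second checks self.offset directly, which TorchScript does not refine, so scripting it is expected to fail.

```python
from typing import Optional

import torch


class AddOffset(torch.nn.Module):
    offset: Optional[int]  # class-level annotation: the attribute may be None

    def __init__(self, offset: Optional[int] = None):
        super().__init__()
        self.offset = offset

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Copy the attribute to a local so the "is not None" check refines it to int
        offset = self.offset
        if offset is not None:
            return input + offset
        return input


class AddOffsetNoLocal(torch.nn.Module):
    offset: Optional[int]

    def __init__(self, offset: Optional[int] = None):
        super().__init__()
        self.offset = offset

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # self.offset is still Optional[int] inside the branch, so this does not compile
        if self.offset is not None:
            return input + self.offset
        return input


print(torch.jit.script(AddOffset(3))(torch.zeros(2)))
try:
    torch.jit.script(AddOffsetNoLocal(3))
except Exception as err:
    print("scripting failed:", type(err).__name__)
```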

For more details on Optional type annotation, see versions 1 and 2 of the official TorchScript Language Reference.