TorchScript and Inheritance

TorchScript does not support inheritance between your own model classes, so each model class should inherit directly from torch.nn.Module rather than from another of your models. Because one model cannot be defined as a subclass of another, rewrite your models so that they are independent of each other. Alternatively, if you need to call one model from the forward function of another, define one model as an attribute of the other, for instance FirstModel.SecondModel.
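As a minimal sketch of both points, assuming a recent PyTorch release, the toy modules below are hypothetical (`Base`, `Child`, `SecondModel`, `FirstModel` and the helper `try_script` exist only for illustration). `Child.forward` reaches its parent through `super()`, which `torch.jit.script` is expected to reject, whereas `FirstModel` holds a `SecondModel` instance as an attribute and should script cleanly.

```python
import torch
import torch.nn as nn


class Base(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


class Child(Base):
    # Inheritance: forward delegates to the parent through super(),
    # which TorchScript cannot compile.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x) + 1


class SecondModel(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


class FirstModel(nn.Module):
    # Composition: SecondModel is stored as an attribute (a submodule)
    # and called from this model's forward function.
    def __init__(self):
        super().__init__()
        self.SecondModel = SecondModel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.SecondModel(x) + 1


def try_script(module: nn.Module):
    """Return the scripted module, or None if scripting fails."""
    try:
        return torch.jit.script(module)
    except Exception as err:  # the exact exception type varies between PyTorch versions
        print(f"{type(module).__name__}: cannot script ({type(err).__name__})")
        return None


x = torch.ones(2)
print(try_script(Child()))      # expected to fail: super() is not supported
scripted = try_script(FirstModel())
print(scripted(x))              # tensor([3., 3.])
```

Both versions compute the same result in eager PyTorch; only the composed version survives scripting.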


Consider an object detection model, ObjectDetection, and a Cifar10ResNet subclass of it, defined as follows in PyTorch:

import torch
import torch.nn as nn
from resnetSmall import ResNet

class ObjectDetection(torch.nn.Module):
   """
   This class uses a ResNet to recognize 100 object classes in an image. The forward
   function returns an image of ones if an airplane is detected, and zeros otherwise.
   """
   def __init__(self):
      super(ObjectDetection, self).__init__()
      self.model = ResNet()
      self.model.load_state_dict(torch.load('resnet_100.chkpt'))
      self.model.eval()

   def forward(self, input):
      """
      Returns a 1 x 1 x H x W image of ones if an airplane is detected in the
      input image, and an image of zeros otherwise.
      """
      modelOutput = self.model(input)
      modelLabel = int(torch.argmax(modelOutput[0]))

      if modelLabel == 1:
         output = torch.ones(1, 1, input.shape[2], input.shape[3])
      else:
         output = torch.zeros(1, 1, input.shape[2], input.shape[3])
      return output

class Cifar10ResNet(ObjectDetection):
   """
   This class recognizes the 10 CIFAR-10 object classes in an image. It uses the
   forward function from the ObjectDetection class and returns an image of ones if
   an airplane is detected, and an image of zeros otherwise.
   """
   def __init__(self):
      super(Cifar10ResNet, self).__init__()
      self.model = ResNet()
      self.model.load_state_dict(torch.load('resnet_cifar10.chkpt'))
      self.model.eval()

   def forward(self, input):
      return super().forward(input)
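The airplane test inside forward does not depend on the ResNet itself. As an illustrative sketch, the same label-to-mask rule can be written as a standalone scripted function; `label_to_mask` is a hypothetical helper name, and the fake scores below stand in for real ResNet output.

```python
import torch


@torch.jit.script
def label_to_mask(model_output: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # Same rule as ObjectDetection.forward: take the highest-scoring class
    # for the first image in the batch; label 1 gives ones, anything else zeros.
    model_label = int(torch.argmax(model_output[0]))
    if model_label == 1:
        return torch.ones(1, 1, height, width)
    return torch.zeros(1, 1, height, width)


# Fake classifier scores for a batch of one image over 10 classes.
scores = torch.zeros(1, 10)
scores[0, 1] = 5.0
mask = label_to_mask(scores, 4, 6)
print(mask.shape, mask.sum().item())  # torch.Size([1, 1, 4, 6]) 24.0
```

Because this logic contains no super() call, it compiles on its own; the scripting problem in Cifar10ResNet comes solely from the inheritance.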

The Cifar10ResNet model can be used for inference in PyTorch with the following code, where input_image is a 1 x 3 x H x W image tensor:

my_model = Cifar10ResNet()
output_image = my_model(input_image)

However, this model cannot be converted to TorchScript because of the class inheritance: Cifar10ResNet.forward calls super().forward(), which TorchScript cannot compile. To convert it, define Cifar10ResNet independently of ObjectDetection, inheriting only from torch.nn.Module and with its own forward function:

class Cifar10ResNet(torch.nn.Module):
   """
   This class recognizes the 10 CIFAR-10 object classes in an image. Its forward
   function returns an image of ones if an airplane is detected, and an image of
   zeros otherwise.
   """
   def __init__(self):
      super(Cifar10ResNet, self).__init__()
      self.model = ResNet()
      self.model.load_state_dict(torch.load('resnet_cifar10.chkpt'))
      self.model.eval()

   def forward(self, input):
      modelOutput = self.model(input)
      modelLabel = int(torch.argmax(modelOutput[0]))

      if modelLabel == 1:
         output = torch.ones(1, 1, input.shape[2], input.shape[3])
      else:
         output = torch.zeros(1, 1, input.shape[2], input.shape[3])
      return output

This model can now be converted to TorchScript as follows:

my_model = Cifar10ResNet()
module = torch.jit.script(my_model)
module.save('cifar10_resnet.pt')
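Because resnetSmall and the checkpoint files are not shown here, the following sketch swaps in a hypothetical `TinyClassifier` so that the standalone Cifar10ResNet can be scripted, saved and reloaded end to end. The checkpoint loading is omitted, and the temporary file path is only for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for resnetSmall.ResNet: 10 class scores per image."""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(3, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.flatten(self.pool(x), 1))


class Cifar10ResNet(nn.Module):
    """Standalone version: inherits only from nn.Module, no super() in forward."""
    def __init__(self):
        super().__init__()
        self.model = TinyClassifier()  # the real model would load 'resnet_cifar10.chkpt' here
        self.model.eval()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        model_output = self.model(input)
        model_label = int(torch.argmax(model_output[0]))
        if model_label == 1:
            return torch.ones(1, 1, input.shape[2], input.shape[3])
        return torch.zeros(1, 1, input.shape[2], input.shape[3])


module = torch.jit.script(Cifar10ResNet())
path = os.path.join(tempfile.mkdtemp(), "cifar10_resnet.pt")
module.save(path)

# Reload without the Python class definitions and run a dummy 1 x 3 x H x W image.
loaded = torch.jit.load(path)
with torch.no_grad():
    result = loaded(torch.rand(1, 3, 32, 48))
print(result.shape)  # torch.Size([1, 1, 32, 48])
```

With random weights the mask may be all ones or all zeros; only its shape is predictable.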

Alternatively, to call one model from the forward function of another, define one model as an attribute of the other. For example, ObjectDetection can instead be defined as a wrapper that holds a Cifar10ResNet instance as an attribute:

class ObjectDetection(torch.nn.Module):
   """
   This class is a wrapper around the Cifar10ResNet model.
   """
   def __init__(self):
      super(ObjectDetection, self).__init__()
      self.Cifar10ResNet = Cifar10ResNet()

   def forward(self, input):
      output = self.Cifar10ResNet(input)
      return output


class Cifar10ResNet(torch.nn.Module):
   """
   This class is a wrapper for our custom ResNet model.
   """
   def __init__(self):
      super(Cifar10ResNet, self).__init__()
      self.model = ResNet()
      self.model.load_state_dict(torch.load('resnet_cifar10.chkpt'))
      self.model.eval()

   def forward(self, input):
      """
      :param input: A torch.Tensor of size 1 x 3 x H x W representing the input image
      :return: A torch.Tensor of size 1 x 1 x H x W of zeros or ones
      """
      modelOutput = self.model(input)
      modelLabel = int(torch.argmax(modelOutput[0]))

      if modelLabel == 1:
         output = torch.ones(1, 1, input.shape[2], input.shape[3])
      else:
         output = torch.zeros(1, 1, input.shape[2], input.shape[3])
      return output

Here, a Cifar10ResNet instance is stored as an attribute of ObjectDetection, so its forward function can be called from the ObjectDetection forward function. The ObjectDetection model can then be converted to TorchScript with the following code:

my_model = ObjectDetection()
module = torch.jit.script(my_model)
module.save('object_detection.pt')
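As a rough sketch of what the attribute approach preserves, the example below replaces Cifar10ResNet with a hypothetical stand-in that ignores its input contents, scripts the ObjectDetection wrapper, saves and reloads it, and then reaches the inner model through the same attribute name. It assumes that submodules of a loaded TorchScript module remain accessible as attributes, as they are in recent PyTorch releases.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Cifar10ResNet(nn.Module):
    """Hypothetical stand-in: no ResNet, just a mask of ones matching the input size."""
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.ones(1, 1, input.shape[2], input.shape[3])


class ObjectDetection(nn.Module):
    def __init__(self):
        super().__init__()
        self.Cifar10ResNet = Cifar10ResNet()  # submodule stored as an attribute

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.Cifar10ResNet(input)


module = torch.jit.script(ObjectDetection())
path = os.path.join(tempfile.mkdtemp(), "object_detection.pt")
module.save(path)

# The attribute survives saving and loading as a scripted submodule.
loaded = torch.jit.load(path)
print([name for name, _ in loaded.named_children()])  # ['Cifar10ResNet']
inner = loaded.Cifar10ResNet
print(inner(torch.rand(1, 3, 8, 8)).shape)            # torch.Size([1, 1, 8, 8])
```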

In the section Accessing Attributes of Attributes, we will discuss how to use a custom knob in Nuke to access an attribute of a model defined as an attribute of another model.