TorchScript Type Annotation

TorchScript is statically typed, which means that the type of every variable must be known at compile time. The compiler infers many types on its own, but it assumes that any unannotated function parameter is a Tensor. For this reason, it may be necessary to annotate variable types in your model so that every local variable has a static type and every function has a statically typed signature. For more information on annotations in TorchScript, see this tutorial.
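As a minimal sketch of what this means in practice, the example below compiles two small functions with torch.jit.script. All function names here (scale_unannotated, scale_annotated, the two callers, and try_script) are hypothetical and not part of the tutorial's model. The unannotated version is expected to fail to compile when it is called with a tuple of floats, because its parameters are assumed to be Tensors, while the annotated version should compile.

```python
from typing import Tuple

import torch


def scale_unannotated(x, factors):
    # No annotations: TorchScript assumes both x and factors are Tensors.
    return x * factors[0]


def scale_annotated(x: torch.Tensor, factors: Tuple[float, float]) -> torch.Tensor:
    # factors is declared as a tuple of two floats, so factors[0] is a float.
    return x * factors[0]


def call_unannotated(x: torch.Tensor) -> torch.Tensor:
    # Passes a tuple where TorchScript expects a Tensor.
    return scale_unannotated(x, (2.0, 3.0))


def call_annotated(x: torch.Tensor) -> torch.Tensor:
    # Passes a tuple that matches the declared parameter type.
    return scale_annotated(x, (2.0, 3.0))


def try_script(fn):
    """Return (scripted_fn, None) on success or (None, error_message) on failure."""
    try:
        return torch.jit.script(fn), None
    except Exception as err:  # compile errors surface as Python exceptions
        return None, str(err)


bad, bad_error = try_script(call_unannotated)
good, good_error = try_script(call_annotated)

if bad_error is not None:
    print("Unannotated version failed to compile.")
if good is not None:
    print(good(torch.ones(2)))
```

Note that the code must live in a file (or another place where Python can retrieve its source), since torch.jit.script compiles from the function's source text.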


Returning to our simple Cifar10ResNet example, let's add a normalisation step to the forward function:

def normalize(self, input, mean, std):
   """
   This method normalizes the values in input based on mean and std.
   :param input: a torch.Tensor of the size [batch x 3 x H x W]
   :param mean: A tuple of float values that represent the mean of
                the r,g,b chans e.g. (0.5, 0.5, 0.5)
   :param std: A tuple of float values that represent the std of the
               r,g,b chans e.g. (0.5, 0.5, 0.5)
   :return: a torch.Tensor that has been normalised
   """
   input[:, 0, :, :] = (input[:, 0, :, :] - mean[0]) / std[0]
   input[:, 1, :, :] = (input[:, 1, :, :] - mean[1]) / std[1]
   input[:, 2, :, :] = (input[:, 2, :, :] - mean[2]) / std[2]
   return input

def forward(self, input):
   """
   :param input: A torch.Tensor of size 1 x 3 x H x W representing the input image
   :return: A torch.Tensor of size 1 x 1 x H x W of zeros or ones
   """

   # Normalise the input tensor
   mean = (0.5, 0.5, 0.5)
   std = (0.5, 0.5, 0.5)
   input = self.normalize(input, mean, std)

   modelOutput = self.model.forward(input)
   modelLabel = int(torch.argmax(modelOutput[0]))

   plane = 0
   if modelLabel == plane:
      output = torch.ones(1, 1, input.shape[2], input.shape[3])
   else:
      output = torch.zeros(1, 1, input.shape[2], input.shape[3])
   return output

If you try to convert this model to TorchScript, you will get the following TorchScript error:

[Image: _images/type-annotation-01.png, the TorchScript error message]

This error indicates that, when converting the model to TorchScript, the compiler assumed that the parameter mean was of type Tensor (the default for unannotated parameters), but normalize() was actually called with a tuple of floats. To overcome this, we need to include a type annotation in the normalize() function as follows:

def normalize(self, input, mean, std):
   # type: (Tensor, Tuple[float, float, float], Tuple[float, float, float]) -> Tensor
   """
   This method normalizes the values in input based on mean and std.
   :param input: a torch.Tensor of the size [batch x 3 x H x W]
   :param mean: A tuple of float values that represent the mean of
                the r,g,b chans e.g. (0.5, 0.5, 0.5)
   :param std: A tuple of float values that represent the std of the
               r,g,b chans e.g. (0.5, 0.5, 0.5)
   :return: a torch.Tensor that has been normalised
   """
   # The type comment above lists the types of input, mean and std (self is omitted).
   input[:, 0, :, :] = (input[:, 0, :, :] - mean[0]) / std[0]
   input[:, 1, :, :] = (input[:, 1, :, :] - mean[1]) / std[1]
   input[:, 2, :, :] = (input[:, 2, :, :] - mean[2]) / std[2]
   return input

Now that these annotations have been added, this model can be successfully converted to the TorchScript format without error.