TorchScript Type Annotation
TorchScript is statically typed: at compile time, every local variable must have a single static type and every function must have a statically typed signature. TorchScript can infer many of these types itself, but where it cannot (for example, function parameters without annotations are assumed to be Tensors), you need to annotate them explicitly in your model. For more information on annotations in TorchScript, see this tutorial.
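As a minimal sketch of this default behaviour, the snippet below scripts the same small function with and without annotations and collects the argument types TorchScript recorded in each compiled signature. The function names are hypothetical; the point is that the unannotated `offset` is compiled as a Tensor, while the annotated one is compiled as a float.

```python
import torch


def add_offset(x, offset):
    # No annotations: TorchScript assumes every parameter is a Tensor
    return x + offset


def add_offset_annotated(x: torch.Tensor, offset: float) -> torch.Tensor:
    # Explicit annotation: offset is compiled as a Python float
    return x + offset


scripted = torch.jit.script(add_offset)
scripted_annotated = torch.jit.script(add_offset_annotated)

# Inspect the argument types stored in each compiled schema
plain_types = [str(arg.type) for arg in scripted.schema.arguments]
annotated_types = [str(arg.type) for arg in scripted_annotated.schema.arguments]
print(plain_types)
print(annotated_types)
```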
Returning to our simple Cifar10ResNet example, let's add a simple normalisation method and call it from forward():
def normalize(self, input, mean, std):
    """
    This method normalizes the values in input based on mean and std.
    :param input: a torch.Tensor of the size [batch x 3 x H x W]
    :param mean: A tuple of float values that represent the mean of
        the r,g,b chans e.g. (0.5, 0.5, 0.5)
    :param std: A tuple of float values that represent the std of the
        r,g,b chans e.g. (0.5, 0.5, 0.5)
    :return: a torch.Tensor that has been normalised
    """
    input[:, 0, :, :] = (input[:, 0, :, :] - mean[0]) / std[0]
    input[:, 1, :, :] = (input[:, 1, :, :] - mean[1]) / std[1]
    input[:, 2, :, :] = (input[:, 2, :, :] - mean[2]) / std[2]
    return input

def forward(self, input):
    """
    :param input: A torch.Tensor of size 1 x 3 x H x W representing the input image
    :return: A torch.Tensor of size 1 x 1 x H x W of zeros or ones
    """
    # Normalise the input tensor
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    input = self.normalize(input, mean, std)
    modelOutput = self.model.forward(input)
    modelLabel = int(torch.argmax(modelOutput[0]))
    plane = 0
    if modelLabel == plane:
        output = torch.ones(1, 1, input.shape[2], input.shape[3])
    else:
        output = torch.zeros(1, 1, input.shape[2], input.shape[3])
    return output
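The failure described next can be reproduced without the full ResNet. The sketch below is a stripped-down, hypothetical stand-in (the class name `UnannotatedNormalizer` is illustrative) that keeps only the unannotated normalize() and the call from forward() that passes tuples of floats.

```python
import torch
import torch.nn as nn


class UnannotatedNormalizer(nn.Module):
    """Hypothetical stand-in reduced to the normalize() call."""

    def normalize(self, input, mean, std):
        # No annotations: TorchScript infers mean and std to be Tensors
        input[:, 0, :, :] = (input[:, 0, :, :] - mean[0]) / std[0]
        input[:, 1, :, :] = (input[:, 1, :, :] - mean[1]) / std[1]
        input[:, 2, :, :] = (input[:, 2, :, :] - mean[2]) / std[2]
        return input

    def forward(self, input):
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        # Tuple[float, float, float] does not match the inferred Tensor type
        return self.normalize(input, mean, std)


try:
    torch.jit.script(UnannotatedNormalizer())
    scripting_failed = False
except RuntimeError as err:
    # The compiler reports the argument type mismatch for mean
    scripting_failed = True
    print(err)
```

The module runs normally in eager Python; the error only appears when torch.jit.script compiles forward() and type-checks the call.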
If you try to convert this model to TorchScript, compilation fails with a type error at the call to normalize().
This error indicates that the compiler assumed the parameter mean was a Tensor, because TorchScript treats unannotated parameters as Tensors, whereas forward() actually passes a tuple of floats (the same applies to std). To overcome this, we need to annotate the parameter and return types of normalize() as follows:
def normalize(self, input, mean, std):
    # type: (Tensor, Tuple[float, float, float], Tuple[float, float, float]) -> Tensor
    """
    This method normalizes the values in input based on mean and std.
    :param input: a torch.Tensor of the size [batch x 3 x H x W]
    :param mean: A tuple of float values that represent the mean of the
        r,g,b chans e.g. (0.5, 0.5, 0.5)
    :param std: A tuple of float values that represent the std of the
        r,g,b chans e.g. (0.5, 0.5, 0.5)
    :return: a torch.Tensor that has been normalised
    """
    input[:, 0, :, :] = (input[:, 0, :, :] - mean[0]) / std[0]
    input[:, 1, :, :] = (input[:, 1, :, :] - mean[1]) / std[1]
    input[:, 2, :, :] = (input[:, 2, :, :] - mean[2]) / std[2]
    return input
Now that these annotations have been added, this model can be successfully converted to the TorchScript format without error.