Putting it all Together: Example 1

Now that we have covered the points to bear in mind when writing a model and converting it from PyTorch to a .cat file, let’s extend our simple Cifar10ResNet model into a more advanced version.

TorchScript

Let’s define our Cifar10ResNet model as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from resnetSmall import ResNet, ResidualBlock

class Cifar10ResNet(torch.nn.Module):
    """
    This is a wrapper class for our custom ResNet model. Its purpose is to
    preprocess the input tensors and ensure that all tensors are created on
    the correct device with the correct dtype. It also defines the variable
    userLabel as an attribute that can be controlled by a custom knob in Nuke.
    """
    def __init__(self, userLabel=0):
        """
        :param userLabel: This int controls which object is detected by the
                          Cifar10ResNet model and can be linked to an enumeration
                          knob in Nuke.
        """
        super(Cifar10ResNet, self).__init__()
        self.model = ResNet(pretrained=True)
        self.model.eval()
        self.userLabel = userLabel

Note that this class defines the attribute userLabel, which can be controlled by a custom knob in Nuke.
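To illustrate how userLabel behaves once the model has been scripted, here is a minimal sketch. It uses a hypothetical stand-in classifier, because resnetSmall is not shown here. TinyClassifier and LabelWrapper are illustrative names that are not part of this example. The sketch assumes that changing the Nuke knob amounts to setting userLabel on the scripted module. Because userLabel is a plain int attribute rather than a constant, TorchScript allows it to be reassigned after scripting, and forward() reads the current value on each call.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for ResNet: always scores class 3 highest."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores = torch.zeros((x.shape[0], 10), dtype=x.dtype, device=x.device)
        scores[:, 3] = 1.0
        return scores


class LabelWrapper(nn.Module):
    """Hypothetical wrapper that stores userLabel the same way Cifar10ResNet does."""

    def __init__(self, userLabel: int = 0):
        super().__init__()
        self.model = TinyClassifier()
        self.userLabel = userLabel  # plain int attribute, mutable after scripting

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        label = int(torch.argmax(self.model(x)[0]))
        if label == self.userLabel:
            return torch.ones((1, 1, x.shape[2], x.shape[3]), dtype=x.dtype, device=x.device)
        return torch.zeros((1, 1, x.shape[2], x.shape[3]), dtype=x.dtype, device=x.device)


scripted = torch.jit.script(LabelWrapper())
image = torch.rand(1, 3, 8, 8)

before = scripted(image)  # userLabel is 0, the classifier reports 3: all zeros
scripted.userLabel = 3    # assumed equivalent of picking another knob entry
after = scripted(image)   # labels now match: all ones
```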

Next, let’s add the normalize() and forward() methods to the Cifar10ResNet class as follows:

    def normalize(self, input, mean, std):
        """
        This method normalizes the values in input based on mean and std.
        :param input: a torch.Tensor of size [batch x 3 x H x W]
        :param mean: a tuple of float values representing the mean of the
                     R, G and B channels, e.g. (0.5, 0.5, 0.5)
        :param std: a tuple of float values representing the standard deviation
                    of the R, G and B channels, e.g. (0.5, 0.5, 0.5)
        :return: a torch.Tensor that has been normalised
        """
        # type: (Tensor, Tuple[float, float, float], Tuple[float, float, float]) -> Tensor
        input[:, 0, :, :] = (input[:, 0, :, :] - mean[0]) / std[0]
        input[:, 1, :, :] = (input[:, 1, :, :] - mean[1]) / std[1]
        input[:, 2, :, :] = (input[:, 2, :, :] - mean[2]) / std[2]
        return input

    def forward(self, input):
        """
        The forward function for this model normalizes the input and then passes it
        to the ResNet forward function. It compares the returned label to userLabel
        and returns a tensor of all ones or all zeros.
        :param input: A torch.Tensor of size 1 x 3 x H x W representing the input image
        :return: A torch.Tensor of size 1 x 1 x H x W of zeros or ones
        """
        # Determine which device all tensors should be created on
        if input.is_cuda:
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')

        # Normalise the input tensor
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        input = self.normalize(input, mean, std)

        modelOutput = self.model(input)
        modelLabel = int(torch.argmax(modelOutput[0]))

        # Check if the detected object is the same as userLabel
        if modelLabel == self.userLabel:
            # Ensure output is created on the correct device with the correct dtype
            output = torch.ones((1, 1, input.shape[2], input.shape[3]), dtype=input.dtype, device=device)
        else:
            # Ensure output is created on the correct device with the correct dtype
            output = torch.zeros((1, 1, input.shape[2], input.shape[3]), dtype=input.dtype, device=device)
        return output

Note that normalize() is applied to the input tensor so that its values are in the range expected by the model. With a mean and standard deviation of 0.5 for every channel, values in the range [0, 1] are mapped to [-1, 1]. TorchScript assumes that any unannotated argument is a Tensor, so normalize() carries a type comment declaring that mean and std are tuples of three floats and that the method returns a Tensor. The forward() method accepts an input tensor of size 1 x 3 x H x W and outputs a tensor of size 1 x 1 x H x W. It also ensures that every tensor it creates is on the correct device and has the same dtype as the input.
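The normalisation arithmetic and the role of the type comment can be checked in isolation. The sketch below is a free-function copy of normalize() (no self), scripted directly. It relies on TorchScript resolving Tensor and Tuple inside type comments without explicit imports, as the original method does. Black (0.0) should map to -1.0 and white (1.0) should map to 1.0.

```python
import torch


@torch.jit.script
def normalize(input, mean, std):
    # type: (Tensor, Tuple[float, float, float], Tuple[float, float, float]) -> Tensor
    # The comment above tells TorchScript that mean and std are float tuples;
    # unannotated arguments would otherwise be treated as Tensors.
    input[:, 0, :, :] = (input[:, 0, :, :] - mean[0]) / std[0]
    input[:, 1, :, :] = (input[:, 1, :, :] - mean[1]) / std[1]
    input[:, 2, :, :] = (input[:, 2, :, :] - mean[2]) / std[2]
    return input


mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
black = normalize(torch.zeros(1, 3, 4, 4), mean, std)  # (0.0 - 0.5) / 0.5 = -1.0
white = normalize(torch.ones(1, 3, 4, 4), mean, std)   # (1.0 - 0.5) / 0.5 = 1.0
```

Because normalize() writes into the tensor it receives, the caller's tensor is modified in place. Pass input.clone() if the original values are still needed afterwards.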

We can convert this model to a TorchScript file using the following code:

# Instantiate the wrapper, compile it with TorchScript and save it to disk
resnet = Cifar10ResNet()
module = torch.jit.script(resnet)
module.save('cifar10_resnet.pt')
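Before building the .cat file, it can be useful to confirm that the saved TorchScript file reloads and still returns a 1 x 1 x H x W tensor of all ones or all zeros. The following is a sketch under stated assumptions. StandInResNet is a hypothetical tiny convolutional classifier that replaces resnetSmall.ResNet, which is not shown here. StandInWrapper is a hypothetical, simplified Cifar10ResNet without normalisation. The file is written to a temporary directory and reloaded on the CPU with torch.jit.load.

```python
import os
import tempfile

import torch
import torch.nn as nn


class StandInResNet(nn.Module):
    """Hypothetical stand-in for resnetSmall.ResNet: a tiny 10-class classifier."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Average each class map over H and W to get scores of shape [batch x 10]
        return self.conv(x).mean(dim=[2, 3])


class StandInWrapper(nn.Module):
    """Hypothetical, simplified Cifar10ResNet (no normalisation) for a save/load check."""

    def __init__(self, userLabel: int = 0):
        super().__init__()
        self.model = StandInResNet().eval()
        self.userLabel = userLabel

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        modelLabel = int(torch.argmax(self.model(input)[0]))
        if modelLabel == self.userLabel:
            return torch.ones((1, 1, input.shape[2], input.shape[3]),
                              dtype=input.dtype, device=input.device)
        return torch.zeros((1, 1, input.shape[2], input.shape[3]),
                           dtype=input.dtype, device=input.device)


# Script and save as in the conversion step, then reload the file from disk
path = os.path.join(tempfile.mkdtemp(), "cifar10_resnet.pt")
torch.jit.script(StandInWrapper()).save(path)
loaded = torch.jit.load(path, map_location="cpu")

with torch.no_grad():
    output = loaded(torch.rand(1, 3, 32, 32))
print(tuple(output.shape), output.dtype)  # (1, 1, 32, 32) torch.float32
```

With random weights, the predicted label, and therefore whether the output is all ones or all zeros, is arbitrary. Only the shape, dtype and uniformity of the output are being checked.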

CatFileCreator

In the CatFileCreator node in Nuke, set the default knobs as follows:

_images/simple-example-01.png

Next, add an enumeration knob by dragging and dropping it onto the top of the CatFileCreator properties panel. Fill in the enumeration knob values as follows:

_images/custom-knobs-01.png

This creates an enumeration knob in your CatFileCreator node. Next, click the Create .cat file and Inference button to create your .cat file and a prepopulated Inference node.
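The entry selected in the enumeration knob is what drives userLabel. As an illustration only, the helpers below assume two things. First, the knob lists the ten CIFAR-10 class names in the dataset's standard order (the order torchvision's CIFAR10 dataset also uses). Second, the index of the selected entry is the int passed to userLabel. CIFAR10_CLASSES, user_label_for and name_for are hypothetical names.

```python
# Standard CIFAR-10 class order. Assumption: the model was trained with this
# order and the enumeration knob lists the same names in the same order.
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def user_label_for(name: str) -> int:
    """Return the int that selecting `name` in the knob is assumed to set."""
    return CIFAR10_CLASSES.index(name)


def name_for(label: int) -> str:
    """Return the class name for a label id, e.g. the argmax of the model output."""
    return CIFAR10_CLASSES[label]


if __name__ == "__main__":
    for name in ("cat", "ship"):
        print(name, "->", user_label_for(name))  # cat -> 3, ship -> 8
```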

Inference

When you open the new Inference node’s properties panel, you will see that the knob values have been prepopulated and that the Detect enumeration knob has been created. Connect the Inference node to an image with red, green and blue channels to get the output result:

_images/custom-knobs-02.png

Note that the Cifar10ResNet model was trained on images in sRGB colour space. Because this input image was originally rendered in sRGB, we have enabled the Raw Data knob in the Read node so that the pixel values are passed through without being converted to linear. Alternatively, a colour space node can be added before the Inference node to convert the image from Nuke’s linear working space to sRGB.
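To see why the colour space matters, the sketch below implements the standard piecewise sRGB transfer curve in both directions. It is not Nuke's implementation; Nuke's colour management may use other transforms, for example through OCIO. For mid-tones, the sRGB encoding of a value is larger than its linear value. This means linear data fed to a model trained on sRGB images would look noticeably darker to it. An 18% linear grey, for instance, encodes to roughly 0.46.

```python
def linear_to_srgb(value: float) -> float:
    """Encode one linear-light value with the standard sRGB transfer curve."""
    if value <= 0.0031308:
        # Linear segment near black (this branch also handles negative values)
        return 12.92 * value
    # Power-law segment for everything brighter
    return 1.055 * value ** (1.0 / 2.4) - 0.055


def srgb_to_linear(value: float) -> float:
    """Decode one sRGB-encoded value back to linear light (inverse of the above)."""
    if value <= 0.04045:
        return value / 12.92
    return ((value + 0.055) / 1.055) ** 2.4


if __name__ == "__main__":
    print(round(linear_to_srgb(0.18), 3))  # 0.461
```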

Now that the input image is connected, change the object selected in the enumeration knob to see the effect it has on the output image. If the input image contains the selected object, the alpha channel of the output image contains all ones; otherwise, it contains all zeros. Toggle the Use GPU if available and Optimise for speed and Memory knobs to confirm that they also work as expected.