Accessing Attributes of Attributes
TorchScript does not support inheritance, so your model classes cannot inherit from one another; each should subclass torch.nn.Module directly. (See the TorchScript documentation for more details.)
Because one model cannot be defined as a subclass of another, you may need to define one model as an attribute of another, for instance FirstModel.SecondModel. Nuke does not allow ‘.’ characters in knob names, so the attributes of SecondModel cannot be connected to a knob directly. Instead, FirstModel must accept an argument that controls SecondModel.attribute and pass its value on. The same approach works when the attribute you want to control is nested inside a torch.nn.Sequential object.
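A minimal sketch of the same idea for an attribute nested inside a torch.nn.Sequential object, assuming PyTorch is installed. The Scale and Grade classes, the factor and userFactor attributes, and the choice of layers are all hypothetical, invented for illustration. Naming the Sequential's children with an OrderedDict lets forward reach the nested layer as self.layers.scale, while userFactor is a flat name (no ‘.’) that a knob could target:

```python
from collections import OrderedDict

import torch


class Scale(torch.nn.Module):
    """Hypothetical layer that multiplies its input by a float attribute."""

    def __init__(self):
        super(Scale, self).__init__()
        self.factor = 1.0

    def forward(self, input):
        return input * self.factor


class Grade(torch.nn.Module):
    """Hypothetical wrapper whose Scale layer sits inside a torch.nn.Sequential."""

    def __init__(self, userFactor=1.0):
        super(Grade, self).__init__()
        # Named children make the nested layer reachable as self.layers.scale.
        self.layers = torch.nn.Sequential(OrderedDict([
            ("scale", Scale()),
            ("clamp", torch.nn.Hardtanh(0.0, 1.0)),  # clamps values to [0, 1]
        ]))
        # Flat attribute (no '.') that a knob named "userFactor" could control.
        self.userFactor = userFactor

    def forward(self, input):
        # Copy the top-level value into the nested layer before running it.
        self.layers.scale.factor = self.userFactor
        return self.layers(input)


scripted = torch.jit.script(Grade())
image = torch.full((1, 3, 2, 2), 0.25)
scripted.userFactor = 2.0
print(scripted(image)[0, 0, 0, 0].item())  # 0.5
scripted.userFactor = 8.0
print(scripted(image)[0, 0, 0, 0].item())  # 1.0 (2.0 clamped by Hardtanh)
```

The nested Scale layer is never addressed by a dotted name from outside the model; only userFactor is.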
TorchScript
Consider the case in which our Cifar10ResNet class is an attribute of another class called ObjectDetection:
import torch


class ObjectDetection(torch.nn.Module):
    """
    This class is a wrapper around the Cifar10ResNet model.
    """
    def __init__(self):
        super(ObjectDetection, self).__init__()
        self.Cifar10ResNet = Cifar10ResNet()

    def forward(self, input):
        output = self.Cifar10ResNet.forward(input)
        return output
class Cifar10ResNet(torch.nn.Module):
    """
    This class is a wrapper for our custom ResNet model.
    """
    def __init__(self):
        super(Cifar10ResNet, self).__init__()
        self.label = 1
        # ResNet is our custom model class, defined elsewhere.
        self.model = ResNet(pretrained=True)
        self.model.eval()

    def forward(self, input):
        """
        :param input: A torch.Tensor of size 1 x 3 x H x W representing the input image
        :return: A torch.Tensor of size 1 x 1 x H x W of zeros or ones
        """
        modelOutput = self.model.forward(input)
        # Index of the highest-scoring class for the first (only) image in the batch.
        modelLabel = int(torch.argmax(modelOutput[0]))
        # Output all ones if the predicted class matches self.label, otherwise all zeros.
        if modelLabel == self.label:
            output = torch.ones(1, 1, input.shape[2], input.shape[3])
        else:
            output = torch.zeros(1, 1, input.shape[2], input.shape[3])
        return output
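To control the label attribute of Cifar10ResNet with a knob, we need to change the ObjectDetection class as follows. ObjectDetection gets a new top-level attribute, userLabel, and its forward function copies that value into Cifar10ResNet.label before running the model: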
class ObjectDetection(torch.nn.Module):
    """
    This class is a wrapper around the Cifar10ResNet model.
    """
    def __init__(self, userLabel=0):
        super(ObjectDetection, self).__init__()
        self.Cifar10ResNet = Cifar10ResNet()
        # Top-level attribute that a knob named "userLabel" can control.
        self.userLabel = userLabel

    def forward(self, input):
        # Copy the knob-controlled value into the nested attribute before running the model.
        self.Cifar10ResNet.label = self.userLabel
        output = self.Cifar10ResNet.forward(input)
        return output
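The following sketch checks that the modified wrapper compiles with torch.jit.script and that changing userLabel changes the output. It assumes PyTorch is installed and replaces the custom ResNet (not shown in this section) with a hypothetical TinyClassifier that outputs 10 class scores. Its weights are random, so the predicted class is arbitrary; the sketch only compares that prediction with userLabel:

```python
import torch


class TinyClassifier(torch.nn.Module):
    """Hypothetical stand-in for the custom ResNet: maps 1 x 3 x H x W to 1 x 10 scores."""

    def __init__(self):
        super(TinyClassifier, self).__init__()
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(3, 10)

    def forward(self, input):
        # Average each channel, then score 10 classes.
        return self.fc(torch.flatten(self.pool(input), 1))


class Cifar10ResNet(torch.nn.Module):
    def __init__(self):
        super(Cifar10ResNet, self).__init__()
        self.label = 1
        self.model = TinyClassifier()
        self.model.eval()

    def forward(self, input):
        modelOutput = self.model.forward(input)
        modelLabel = int(torch.argmax(modelOutput[0]))
        if modelLabel == self.label:
            output = torch.ones(1, 1, input.shape[2], input.shape[3])
        else:
            output = torch.zeros(1, 1, input.shape[2], input.shape[3])
        return output


class ObjectDetection(torch.nn.Module):
    def __init__(self, userLabel=0):
        super(ObjectDetection, self).__init__()
        self.Cifar10ResNet = Cifar10ResNet()
        self.userLabel = userLabel

    def forward(self, input):
        # Propagate the top-level value to the nested attribute on every call.
        self.Cifar10ResNet.label = self.userLabel
        return self.Cifar10ResNet.forward(input)


torch.manual_seed(0)
scripted = torch.jit.script(ObjectDetection())
image = torch.rand(1, 3, 8, 8)
# The class the stand-in classifier predicts for this image.
predicted = int(torch.argmax(scripted.Cifar10ResNet.model(image)[0]))

scripted.userLabel = predicted  # matching label: mask of ones
print(scripted(image).sum().item())  # 64.0 for an 8 x 8 image
scripted.userLabel = (predicted + 1) % 10  # any other label: mask of zeros
print(scripted(image).sum().item())  # 0.0
```

Here userLabel is set from Python; in Nuke, the intent is that the userLabel knob supplies this value.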
CatFileCreator
In CatFileCreator, you can now create a custom knob and set its name to “userLabel”. The knob's value is assigned to ObjectDetection.userLabel, and the forward function then copies it to ObjectDetection.Cifar10ResNet.label, so the knob name does not need to contain any ‘.’ characters.
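As a minimal sketch of the export step, assuming PyTorch is installed: the Wrapper and Inner classes below are hypothetical, reduced stand-ins for ObjectDetection and Cifar10ResNet. After scripting, saving and reloading, the TorchScript module exposes userLabel as a top-level attribute, the kind of name without ‘.’ characters that the custom knob needs, while label is only reachable through the nested module:

```python
import io

import torch


class Inner(torch.nn.Module):
    """Hypothetical stand-in for Cifar10ResNet: only its label attribute matters here."""

    def __init__(self):
        super(Inner, self).__init__()
        self.label = 1

    def forward(self, input):
        # Placeholder computation that depends on label.
        return input * self.label


class Wrapper(torch.nn.Module):
    """Hypothetical stand-in for ObjectDetection, exposing the flat userLabel attribute."""

    def __init__(self, userLabel=0):
        super(Wrapper, self).__init__()
        self.inner = Inner()
        self.userLabel = userLabel

    def forward(self, input):
        self.inner.label = self.userLabel
        return self.inner(input)


# In practice you would save to a file path for CatFileCreator;
# an in-memory buffer keeps this sketch self-contained.
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(Wrapper()), buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

print(loaded.userLabel)          # 0: a top-level name a knob can match
print(hasattr(loaded, "label"))  # False: label lives at loaded.inner.label
```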