Source code for qailab.torch.autograd

  1"""Autograd functions for VQCs"""
  2import numpy as np
  3
  4import torch
  5from torch.autograd import Function
  6from torch.autograd.function import once_differentiable
  7
  8from quantum_launcher import QuantumLauncher
  9
 10from qailab.utils import distribution_to_array
 11from qailab.circuit.utils import filter_params, assign_input_weight
 12
 13# * Using template code from torch generates weird linter errors,
 14# * might have to investigate later, ignoring for now as everything seems to work correctly.
 15
 16
 17def _is_batch_input(t: torch.Tensor, num_input_params: int) -> bool:
 18    return len(t.shape) == 2 and t.shape[1] == num_input_params
 19
 20
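# A quick illustration of the helper above (a minimal sketch): with a circuit exposing
# 3 input parameters, a 2-D tensor of shape (batch, 3) counts as a batch, while a
# 1-D tensor is treated as a single sample:
#
#     _is_batch_input(torch.zeros(8, 3), 3)  # True  -> per-sample forward/backward
#     _is_batch_input(torch.zeros(3), 3)     # False -> single pass
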
class ExpVQCFunction(Function):  # pylint: disable=abstract-method
    """Class implementing forward and backward calculations for ExpQLayer"""
    @staticmethod
    def _forward_single(
        fn_in: torch.Tensor,
        weight: torch.Tensor,
        launcher_forward: QuantumLauncher,
    ) -> torch.Tensor:
        fn_in_numpy = fn_in.cpu().detach().numpy()
        weight_numpy = weight.cpu().detach().numpy()

        params = assign_input_weight(
            launcher_forward.problem.instance,
            fn_in_numpy,
            weight_numpy
        )
        res = launcher_forward.run(parameters=params)
        arr = distribution_to_array(res.distribution)
        t = torch.tensor(arr, dtype=fn_in.dtype, requires_grad=True).to(fn_in.device)

        return t

    @staticmethod
    def forward(  # pylint: disable=arguments-differ
        fn_in: torch.Tensor,
        weight: torch.Tensor,
        launcher_forward: QuantumLauncher,
        launcher_backward: QuantumLauncher  # pylint: disable=unused-argument
    ) -> torch.Tensor:
        """
        Calculation of forward pass.

        Args:
            fn_in (torch.Tensor): Input tensor.
            weight (torch.Tensor): Layer weights.
            launcher_forward (QuantumLauncher): QuantumLauncher with the forward pass algorithm.
            launcher_backward (QuantumLauncher):
                QuantumLauncher with the backward pass algorithm.
                Not used in forward, but needed here as it will get passed to setup_context().

        Returns:
            torch.Tensor: Distribution of the forward pass.
        """

        is_batch = _is_batch_input(fn_in, len(filter_params(launcher_forward.problem.instance, 'input')))

        if is_batch:
            return torch.stack([ExpVQCFunction._forward_single(single_in, weight, launcher_forward) for single_in in fn_in])
        return ExpVQCFunction._forward_single(fn_in, weight, launcher_forward)

    @staticmethod
    def setup_context(ctx, inputs, output):
        """
        Called after forward; saves the forward arguments to be used later in backward.

        Args:
            ctx: Context object that holds information.
            inputs: Args passed to forward().
            output: Output returned by forward().
        """
        fn_in, weight, launcher_forward, launcher_backward = inputs
        ctx.save_for_backward(fn_in, weight, output)
        ctx.launcher_forward = launcher_forward
        ctx.launcher_backward = launcher_backward
        ctx.is_batch = _is_batch_input(fn_in, len(filter_params(launcher_forward.problem.instance, 'input')))

    @staticmethod
    def _backward_single(
        fn_in,
        weight,
        launcher_backward,
        grad_output
    ) -> tuple[torch.Tensor, torch.Tensor]:
        fn_in_numpy = fn_in.cpu().detach().numpy()
        weight_numpy = weight.cpu().detach().numpy()

        params = assign_input_weight(
            launcher_backward.problem.instance,
            fn_in_numpy,
            weight_numpy
        )

        res = launcher_backward.run(parameters=params, auto_bind=False)

        out_grad_numpy = grad_output.cpu().detach().numpy()

        grad_input = res.result['input'] @ out_grad_numpy
        # Allow for weightless QNN layers
        grad_weight = res.result['weight'] @ out_grad_numpy if len(res.result['weight']) > 0 else np.array([])

        # Scale gradient values because we are optimizing weights initialized in range <0,2pi>
        return (
            torch.tensor(grad_input, dtype=fn_in.dtype).to(fn_in.device),
            torch.tensor(grad_weight, dtype=weight.dtype).to(fn_in.device) * np.pi,
        )

    @staticmethod
    @once_differentiable
    def backward(  # pylint: disable=arguments-differ
        ctx,
        grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, None, None]:
        """
        Calculation of backward pass.

        Args:
            ctx: Context object supplied by autograd. Contains saved tensors and QuantumLaunchers.
            grad_output (torch.Tensor): Grad from the next layer.

        Returns:
            tuple[torch.Tensor, torch.Tensor, None, None]:
                Grad for inputs, grad for weights, rest irrelevant.
                (Each forward argument needs to get something, but the launchers don't need grad.)
        """
        forward_tensors = ctx.saved_tensors
        fn_in, weight = forward_tensors[:2]
        launcher_backward = ctx.launcher_backward

        if not ctx.is_batch:
            return *ExpVQCFunction._backward_single(fn_in, weight, launcher_backward, grad_output), None, None

        input_grads, weight_grads = [], []
        for in_single, grad_single in zip(fn_in, grad_output):
            igrad, wgrad = ExpVQCFunction._backward_single(in_single, weight, launcher_backward, grad_single)
            input_grads.append(igrad)
            weight_grads.append(wgrad)

        return torch.stack(input_grads), torch.stack(weight_grads), None, None

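# A minimal usage sketch, assuming `launcher_forward` / `launcher_backward` are
# QuantumLauncher instances prepared elsewhere and `n_inputs` / `n_weights` are the
# numbers of 'input' and 'weight' circuit parameters (illustrative names only).
# As a torch.autograd.Function, ExpVQCFunction is invoked through `apply`, typically
# from the forward() of a layer such as ExpQLayer:
#
#     x = torch.rand(8, n_inputs)                                 # batched input
#     w = (torch.rand(n_weights) * 2 * np.pi).requires_grad_()    # weights in <0,2pi>
#     out = ExpVQCFunction.apply(x, w, launcher_forward, launcher_backward)
#     out.sum().backward()   # calls ExpVQCFunction.backward via autograd
#     # w.grad now holds the pi-scaled weight gradients from _backward_single
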
class ArgMax(Function):  # pylint: disable=abstract-method
    """
    ArgMax function. Propagates the sum of the gradient to the argmax index; the rest is zero.

    https://discuss.pytorch.org/t/differentiable-argmax/33020
    """
    @staticmethod
    def forward(fn_in):  # pylint: disable=arguments-differ
        """
        Forward run.

        Args:
            fn_in (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: First index(es) of maximum elements.
        """
        return torch.tensor(torch.argmax(fn_in, dim=-1, keepdim=True), dtype=fn_in.dtype, requires_grad=True)

    @staticmethod
    def setup_context(ctx, inputs, output):
        """Save tensors for backward pass"""
        ctx.save_for_backward(*inputs, output)

    @staticmethod
    def backward(  # pylint: disable=arguments-differ
        ctx,
        grad_output: torch.Tensor
    ) -> tuple[torch.Tensor]:
        """
        Calculation of backward pass.

        Args:
            ctx: Context object supplied by autograd.
            grad_output (torch.Tensor): Grad from the next layer.

        Returns:
            tuple[torch.Tensor]: Grad w.r.t. input.
        """
        fn_in, idx = ctx.saved_tensors
        grad_input = torch.zeros(fn_in.shape, device=fn_in.device, dtype=fn_in.dtype)
        grad_input.scatter_(-1, torch.tensor(idx, dtype=torch.int64), grad_output.sum(-1, keepdim=True))
        return (grad_input,)
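
# A minimal usage sketch of ArgMax (illustrative, based on the implementation above):
# it is also called through `apply`; the backward pass routes the summed incoming
# gradient to the argmax index and leaves the remaining entries at zero, as described
# in the class docstring:
#
#     probs = torch.tensor([[0.1, 0.7, 0.2]], requires_grad=True)
#     label = ArgMax.apply(probs)              # tensor([[1.]])
#     label.backward(torch.ones_like(label))
#     # probs.grad -> tensor([[0., 1., 0.]])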