1"""Autograd functions for VQCs"""
2import numpy as np
3
4import torch
5from torch.autograd import Function
6from torch.autograd.function import once_differentiable
7
8from quantum_launcher import QuantumLauncher
9
10from qailab.utils import distribution_to_array
11from qailab.circuit.utils import filter_params, assign_input_weight
12
13# * Using template code from torch generates weird linter errors,
14# * might have to investigate later, ignoring for now as everything seems to work correctly.
15
16
17def _is_batch_input(t: torch.Tensor, num_input_params: int) -> bool:
18 return len(t.shape) == 2 and t.shape[1] == num_input_params
19
20
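# Illustrative example (hypothetical values, not part of the module): with
# num_input_params == 3, a 2-D tensor of shape (batch_size, 3) counts as a
# batch, while a 1-D tensor of shape (3,) is treated as a single sample:
#     _is_batch_input(torch.zeros(8, 3), 3)  # True
#     _is_batch_input(torch.zeros(3), 3)     # False
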
class ExpVQCFunction(Function):  # pylint: disable=abstract-method
    """Class implementing forward and backward calculations for ExpQLayer"""
    @staticmethod
    def _forward_single(
        fn_in: torch.Tensor,
        weight: torch.Tensor,
        launcher_forward: QuantumLauncher,
    ) -> torch.Tensor:
        fn_in_numpy = fn_in.cpu().detach().numpy()
        weight_numpy = weight.cpu().detach().numpy()

        # Bind the input and weight values to the circuit parameters.
        params = assign_input_weight(
            launcher_forward.problem.instance,
            fn_in_numpy,
            weight_numpy
        )
        # Run the circuit and turn the measurement distribution into an array.
        res = launcher_forward.run(parameters=params)
        arr = distribution_to_array(res.distribution)
        t = torch.tensor(arr, dtype=fn_in.dtype, requires_grad=True).to(fn_in.device)

        return t

    @staticmethod
    def forward(  # pylint: disable=arguments-differ
        fn_in: torch.Tensor,
        weight: torch.Tensor,
        launcher_forward: QuantumLauncher,
        launcher_backward: QuantumLauncher  # pylint: disable=unused-argument
    ) -> torch.Tensor:
        """
        Calculation of the forward pass.

        Args:
            fn_in (torch.Tensor): Input tensor.
            weight (torch.Tensor): Layer weights.
            launcher_forward (QuantumLauncher): QuantumLauncher with the forward pass algorithm.
            launcher_backward (QuantumLauncher):
                QuantumLauncher with the backward pass algorithm.
                Not used in forward, but required here because it gets passed on to setup_context().

        Returns:
            torch.Tensor: Measurement distribution of the forward pass.
        """

        is_batch = _is_batch_input(fn_in, len(filter_params(launcher_forward.problem.instance, 'input')))

        if is_batch:
            return torch.stack([ExpVQCFunction._forward_single(single_in, weight, launcher_forward) for single_in in fn_in])
        return ExpVQCFunction._forward_single(fn_in, weight, launcher_forward)

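    # Hedged usage sketch (attribute names here are hypothetical): the layer
    # owning this Function would invoke it through Function.apply, e.g.
    #     out = ExpVQCFunction.apply(x, self.weight, self.launcher_forward, self.launcher_backward)
    # where x is a single sample of shape (n_inputs,) or a batch of shape
    # (batch_size, n_inputs).
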
    @staticmethod
    def setup_context(ctx, inputs, output):
        """
        Called after forward; saves the arguments from forward to be used later in backward.

        Args:
            ctx: Context object that holds information.
            inputs: Arguments passed to forward().
            output: Output returned from forward().
        """
        fn_in, weight, launcher_forward, launcher_backward = inputs
        ctx.save_for_backward(fn_in, weight, output)
        ctx.launcher_forward = launcher_forward
        ctx.launcher_backward = launcher_backward
        ctx.is_batch = _is_batch_input(fn_in, len(filter_params(launcher_forward.problem.instance, 'input')))

    @staticmethod
    def _backward_single(
        fn_in: torch.Tensor,
        weight: torch.Tensor,
        launcher_backward: QuantumLauncher,
        grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        fn_in_numpy = fn_in.cpu().detach().numpy()
        weight_numpy = weight.cpu().detach().numpy()

        params = assign_input_weight(
            launcher_backward.problem.instance,
            fn_in_numpy,
            weight_numpy
        )

        res = launcher_backward.run(parameters=params, auto_bind=False)

        out_grad_numpy = grad_output.cpu().detach().numpy()

        grad_input = res.result['input'] @ out_grad_numpy
        # Allow for weightless QNN layers.
        grad_weight = res.result['weight'] @ out_grad_numpy if len(res.result['weight']) > 0 else np.array([])

        # Scale the weight gradient because we are optimizing weights initialized in the range [0, 2*pi].
        return (
            torch.tensor(grad_input, dtype=fn_in.dtype).to(fn_in.device),
            torch.tensor(grad_weight, dtype=weight.dtype).to(fn_in.device) * np.pi,
        )

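    # Math note on _backward_single (an illustrative reading of the code
    # above, not a guarantee about the launcher's result format): it computes
    # a vector-Jacobian product. If res.result['input'] and
    # res.result['weight'] hold the Jacobians of the circuit output w.r.t.
    # inputs and weights, and g is the upstream gradient, then
    #     grad_input  = J_input @ g
    #     grad_weight = (J_weight @ g) * pi
    # with the pi factor compensating for the [0, 2*pi] weight scale.
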
    @staticmethod
    @once_differentiable
    def backward(  # pylint: disable=arguments-differ
        ctx,
        grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, None, None]:
        """
        Calculation of the backward pass.

        Args:
            ctx: Context object supplied by autograd. Contains the saved tensors and qlaunchers.
            grad_output (torch.Tensor): Grad from the next layer.

        Returns:
            tuple[torch.Tensor, torch.Tensor, None, None]:
                Grad for inputs, grad for weights; the rest is irrelevant.
                (Each forward argument needs to receive something, but the launchers don't need grad.)
        """
        forward_tensors = ctx.saved_tensors
        fn_in, weight = forward_tensors[:2]
        launcher_backward = ctx.launcher_backward

        if not ctx.is_batch:
            return *ExpVQCFunction._backward_single(fn_in, weight, launcher_backward, grad_output), None, None

        input_grads, weight_grads = [], []
        for in_single, grad_single in zip(fn_in, grad_output):
            igrad, wgrad = ExpVQCFunction._backward_single(in_single, weight, launcher_backward, grad_single)
            input_grads.append(igrad)
            weight_grads.append(wgrad)

        return torch.stack(input_grads), torch.stack(weight_grads), None, None


class ArgMax(Function):  # pylint: disable=abstract-method
    """
    ArgMax function. Routes the summed upstream gradient to the argmax index; every other entry gets zero.

    https://discuss.pytorch.org/t/differentiable-argmax/33020
    """
    @staticmethod
    def forward(fn_in):  # pylint: disable=arguments-differ
        """
        Forward run.

        Args:
            fn_in (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: First index(es) of the maximum elements, cast to the input dtype.
        """
        # Cast the integer indices to the input dtype so they can flow through
        # downstream float operations; re-wrapping the result in torch.tensor()
        # is unnecessary and triggers a copy warning.
        return torch.argmax(fn_in, dim=-1, keepdim=True).to(fn_in.dtype)

    @staticmethod
    def setup_context(ctx, inputs, output):
        """Save tensors for backward pass"""
        ctx.save_for_backward(*inputs, output)

    @staticmethod
    def backward(  # pylint: disable=arguments-differ
        ctx,
        grad_output: torch.Tensor
    ) -> tuple[torch.Tensor]:
        """
        Calculation of the backward pass.

        Args:
            ctx: Context object supplied by autograd.
            grad_output (torch.Tensor): Grad from the next layer.

        Returns:
            tuple[torch.Tensor]: Grad w.r.t. input.
        """
        fn_in, idx = ctx.saved_tensors
        grad_input = torch.zeros(fn_in.shape, device=fn_in.device, dtype=fn_in.dtype)
        # The indices were cast to a float dtype in forward; scatter_ needs an
        # integer index tensor, so convert back with .to() rather than
        # re-wrapping in torch.tensor(), which would also drop the device.
        grad_input.scatter_(-1, idx.to(torch.int64), grad_output.sum(-1, keepdim=True))
        return (grad_input,)
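

if __name__ == "__main__":
    # Minimal, self-contained sanity check for ArgMax (illustrative only, not
    # part of the library): the upstream gradient should land on the argmax
    # index and nowhere else.
    x = torch.tensor([[0.1, 0.7, 0.2]], requires_grad=True)
    y = ArgMax.apply(x)
    y.backward(torch.ones_like(y))
    print(y)       # tensor([[1.]])
    print(x.grad)  # tensor([[0., 1., 0.]])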