Source code for qailab.qlauncher.passes.forward
1""" Forward pass algorithm implementation in quantum_launcher. """
2from collections import defaultdict
3from collections.abc import Callable
4from typing import Any
5from quantum_launcher.base import Algorithm
6from quantum_launcher.base.base import Backend, Problem, Result
7from quantum_launcher.routines.qiskit_routines import QiskitBackend
8from qiskit.primitives.containers.primitive_result import PrimitiveResult
9from qiskit.primitives.base.sampler_result import SamplerResult
10
11from qailab.utils import number_to_bit_tuple
12
13
[docs]
class ForwardPass(Algorithm):
    """Forward pass algorithm for QLauncher.

    Samples the pubs produced by the problem formatter on a Qiskit backend's
    sampler and returns the measured outcome distribution of the first pub.
    """
    _algorithm_format = 'none'

    def __init__(self, shots: int = 1024) -> None:
        """Forward pass implementation for QLauncher.

        Args:
            shots (int): Number of shots passed to ``sampler.run``. Defaults to 1024.
        """
        self.shots = shots
        super().__init__()

    def run(self, problem: Problem, backend: Backend, formatter: Callable[..., Any] | None = None) -> Result:
        """Sample the formatted problem on ``backend`` and return its distribution.

        Args:
            problem (Problem): Problem instance handed to ``formatter``.
            backend (Backend): Must be a ``QiskitBackend``; its ``sampler`` executes the pubs.
            formatter (Callable[..., Any] | None): Converts ``problem`` into sampler pubs. Required.

        Returns:
            Result: Result whose distribution (5th positional field) maps bit tuples to
            relative frequencies (V2 samplers) or quasi-probabilities (V1 samplers) of the
            first pub. The remaining fields are fixed empty/zero values plus ``self.shots``.

        Raises:
            ValueError: If ``formatter`` is missing, ``backend`` is not a ``QiskitBackend``,
                the sampler returns an unsupported result type, or the result holds no pubs.
        """
        if formatter is None:
            raise ValueError('Formatter for Forward pass not found!')
        if not isinstance(backend, QiskitBackend):
            raise ValueError(
                f'Wrong backend given into ForwardPass: expected QiskitBackend, got {type(backend).__name__}'
            )
        pubs = formatter(problem)
        job = backend.sampler.run(pubs, shots=self.shots)
        result = job.result()
        if isinstance(result, PrimitiveResult):
            distributions = self._extract_results_v2(result)
        elif isinstance(result, SamplerResult):
            distributions = self._extract_results_v1(result)
        else:
            raise ValueError(f'Result with type: {type(result)} is not supported')
        if not distributions:
            # Fail with a clear message instead of an IndexError on distributions[0].
            raise ValueError('Sampler returned no pub results')
        # Only the first pub's distribution is reported.
        # Per the original author, raw sampler results are not picklable, so only plain data is stored.
        return Result('', 0, '', 0, distributions[0], {}, self.shots, 0, 0, None)

    def _extract_results_v2(self, result: PrimitiveResult) -> list[dict]:
        """Convert a V2 sampler result into one distribution per pub.

        Reads the classical register named ``'c'`` of every pub. Each shot is stored
        as a row of bytes (8 measured bits per byte, last byte least significant),
        which is decoded to an integer and then to a bit tuple.

        Args:
            result (PrimitiveResult): Result returned by a V2 sampler job.

        Returns:
            list[dict]: For each pub, a ``defaultdict(float)`` mapping bit tuples to the
            fraction of sampled shots that produced them.
        """
        distributions = []
        for pub in result:  # PrimitiveResult is iterable over its pub results.
            bit_array = pub.data['c']
            num_bits = bit_array.num_bits
            # NOTE(review): assumes a scalar (non-parameter-swept) pub, so each row is one shot — TODO confirm.
            counts = defaultdict(int)
            for row in bit_array.array:
                # Qiskit packs measurements into 8-bit chunks, most significant byte first.
                counts[int.from_bytes(bytes(int(byte) for byte in row), 'big')] += 1
            # Normalize by the shots actually returned: a pub may override the requested shot count.
            total = sum(counts.values())
            distribution = defaultdict(float)
            for value, count in counts.items():
                distribution[number_to_bit_tuple(value, num_bits)] += count / total
            distributions.append(distribution)
        return distributions

    def _extract_results_v1(self, result: SamplerResult) -> list[dict]:
        """Convert a V1 sampler result into one distribution per circuit.

        Args:
            result (SamplerResult): Result returned by a V1 sampler job.

        Returns:
            list[dict]: For each quasi-distribution, a dict mapping bit tuples to the
            quasi-probabilities reported by the sampler.
        """
        distributions = []
        for quasi_dist in result.quasi_dists:
            # NOTE(review): ``_num_bits`` is private; confirm it equals the circuit's classical
            # register width rather than the width of the largest observed outcome.
            num_bits = quasi_dist._num_bits  # pylint: disable=protected-access
            distributions.append({
                number_to_bit_tuple(outcome, num_bits): probability
                for outcome, probability in quasi_dist.items()
            })
        return distributions