{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Quantum Regressor\n", "\n", "## Quantum Doctor\n", "\n", "### How to train a hybrid quantum neural network to predict diabetes progression in patients\n", "\n", "In this tutorial, you will learn how to create a hybrid neural network that uses a quantum regression layer as its output layer. \n", "As training data, we will use a subset of the [diabetes toy dataset](https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html) that ships with `scikit-learn`.\n", "\n", "## Table of contents\n", "- [Dataset preparation and exploration](#dataset-preparation)\n", "- [Model creation](#model-creation)\n", "- [Model training and evaluation](#model-evaluation)\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "# Suppress library warnings to keep the notebook output readable\n", "warnings.simplefilter('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset preparation\n", "\n", "The diabetes data is one of the toy datasets bundled with scikit-learn, so we can load it with the `load_diabetes` function from `sklearn.datasets`. \n", "The feature values in $X$ are already preprocessed (each column is mean-centered and scaled by its standard deviation times the square root of the number of samples), so no further scaling is needed. \n", "We will use only the first 200 of the 442 samples: 100 for training and 100 for evaluation." 
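, "\n", "\n", "As an optional sanity check (a sketch added for illustration, not part of the original workflow), the next cell confirms the preprocessing described in the scikit-learn documentation: on the full 442-sample dataset, every feature column should have a mean of about 0 and a sum of squares of about 1. The name `X_full` is introduced only for this check, so the `X` used in the rest of the tutorial is not affected. If the data match that description, both checks print `True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n",
"from sklearn.datasets import load_diabetes\n",
"\n",
"# Load all 442 samples as NumPy arrays; the target is not needed for this check\n",
"X_full, _ = load_diabetes(return_X_y=True)\n",
"\n",
"# Mean-centered columns: each column mean should be approximately 0\n",
"print(np.allclose(X_full.mean(axis=0), 0))\n",
"# Scaled columns: each column's sum of squares should be approximately 1\n",
"print(np.allclose((X_full ** 2).sum(axis=0), 1))" 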
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn.datasets import load_diabetes\n", "\n", "diabetes = load_diabetes()\n", "\n", "# Keep only the first 200 of the 442 samples\n", "X = pd.DataFrame(diabetes.data[:200], columns=diabetes.feature_names)\n", "y = pd.Series(diabetes.target[:200])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The diabetes dataset consists of 10 numeric features: age, sex (2 possible values), body mass index (*bmi*), average blood pressure (*bp*), and six blood serum measurements (*s1* to *s6*, described on the [sklearn dataset page](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset)). The target is a quantitative measure of disease progression one year after baseline." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
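<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>age</th>\n", "      <th>sex</th>\n", "      <th>bmi</th>\n", "      <th>bp</th>\n", "      <th>s1</th>\n", "      <th>s2</th>\n", "      <th>s3</th>\n", "      <th>s4</th>\n", "      <th>s5</th>\n", "      <th>s6</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>0.038076</td>\n", "      <td>0.050680</td>\n", "      <td>0.061696</td>\n", "      <td>0.021872</td>\n", "      <td>-0.044223</td>\n", "      <td>-0.034821</td>\n", "      <td>-0.043401</td>\n", "      <td>-0.002592</td>\n", "      <td>0.019907</td>\n", "      <td>-0.017646</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>-0.001882</td>\n", "      <td>-0.044642</td>\n", "      <td>-0.051474</td>\n", "      <td>-0.026328</td>\n", "      <td>-0.008449</td>\n", "      <td>-0.019163</td>\n", "      <td>0.074412</td>\n", "      <td>-0.039493</td>\n", "      <td>-0.068332</td>\n", "      <td>-0.092204</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>0.085299</td>\n", "      <td>0.050680</td>\n", "      <td>0.044451</td>\n", "      <td>-0.005670</td>\n", "      <td>-0.045599</td>\n", "      <td>-0.034194</td>\n", "      <td>-0.032356</td>\n", "      <td>-0.002592</td>\n", "      <td>0.002861</td>\n", "      <td>-0.025930</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>-0.089063</td>\n", "      <td>-0.044642</td>\n", "      <td>-0.011595</td>\n", "      <td>-0.036656</td>\n", "      <td>0.012191</td>\n", "      <td>0.024991</td>\n", "      <td>-0.036038</td>\n", "      <td>0.034309</td>\n", "      <td>0.022688</td>\n", "      <td>-0.009362</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>0.005383</td>\n", "      <td>-0.044642</td>\n", "      <td>-0.036385</td>\n", "      <td>0.021872</td>\n", "      <td>0.003935</td>\n", "      <td>0.015596</td>\n", "      <td>0.008142</td>\n", "      <td>-0.002592</td>\n", "      <td>-0.031988</td>\n", "      <td>-0.046641</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>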