{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Quantum Regressor\n", "\n", "## Quantum Doctor\n", "\n", "### How to train a hybrid quantum-classical neural network to predict diabetes progression in patients.\n", "\n", "In this tutorial, you will learn how to create a hybrid neural network that uses a quantum regression layer as its output layer. \n", "As our dataset, we will use a subset of the [diabetes toy dataset](https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html) available through `scikit-learn`.\n", "\n", "## Table of contents\n", "- [Dataset preparation and exploration](#dataset-preparation)\n", "- [Model creation](#model-creation)\n", "- [Model training and evaluation](#model-evaluation)\n" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "# Suppress all warnings to keep the notebook output readable\n", "warnings.simplefilter('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset preparation\n", "\n", "We can load our dataset with the `load_diabetes` function from `sklearn.datasets`, since it is one of the toy datasets bundled with the library. \n", "You might notice that the $X$ values are already preprocessed: each feature has been mean-centered and scaled by its standard deviation times the square root of the number of samples, so that over all 442 samples its sum of squares equals 1. This saves us a normalization step. \n", "We will use only the first 200 of the 442 samples: 100 for training and 100 for evaluation."
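, "\n", "\n", "As an optional sanity check, the cell below (a minimal sketch added for illustration, not part of the original workflow) verifies this scaling on the full dataset: each feature column should have a mean of approximately 0 and a sum of squares of approximately 1. The tolerance of `1e-6` is an arbitrary choice." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative check, assuming the default load_diabetes() call returns the scaled features\n", "import numpy as np\n", "from sklearn.datasets import load_diabetes\n", "\n", "features = load_diabetes().data  # all 442 samples, 10 feature columns\n", "\n", "# Each column should be mean-centered ...\n", "print('mean ~ 0:', np.allclose(features.mean(axis=0), 0.0, atol=1e-6))\n", "# ... and scaled so that its sum of squares is ~1\n", "print('sum of squares ~ 1:', np.allclose((features ** 2).sum(axis=0), 1.0, atol=1e-6))"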
] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn.datasets import load_diabetes\n", "\n", "diabetes = load_diabetes()\n", "\n", "# Keep only the first 200 samples: features as a DataFrame, targets as a Series\n", "X = pd.DataFrame(diabetes.data[:200], columns=diabetes.feature_names)\n", "y = pd.Series(diabetes.target[:200])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The diabetes dataset consists of 10 numeric feature variables: age, sex (2 possible values), body mass index (*bmi*), average blood pressure (*bp*), and six blood serum measurements, *s1* to *s6* (described on the [sklearn dataset page](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset)). The target $y$ is a quantitative measure of disease progression one year after baseline." ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | age | \n", "sex | \n", "bmi | \n", "bp | \n", "s1 | \n", "s2 | \n", "s3 | \n", "s4 | \n", "s5 | \n", "s6 | \n", "
---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0.038076 | \n", "0.050680 | \n", "0.061696 | \n", "0.021872 | \n", "-0.044223 | \n", "-0.034821 | \n", "-0.043401 | \n", "-0.002592 | \n", "0.019907 | \n", "-0.017646 | \n", "
1 | \n", "-0.001882 | \n", "-0.044642 | \n", "-0.051474 | \n", "-0.026328 | \n", "-0.008449 | \n", "-0.019163 | \n", "0.074412 | \n", "-0.039493 | \n", "-0.068332 | \n", "-0.092204 | \n", "
2 | \n", "0.085299 | \n", "0.050680 | \n", "0.044451 | \n", "-0.005670 | \n", "-0.045599 | \n", "-0.034194 | \n", "-0.032356 | \n", "-0.002592 | \n", "0.002861 | \n", "-0.025930 | \n", "
3 | \n", "-0.089063 | \n", "-0.044642 | \n", "-0.011595 | \n", "-0.036656 | \n", "0.012191 | \n", "0.024991 | \n", "-0.036038 | \n", "0.034309 | \n", "0.022688 | \n", "-0.009362 | \n", "
4 | \n", "0.005383 | \n", "-0.044642 | \n", "-0.036385 | \n", "0.021872 | \n", "0.003935 | \n", "0.015596 | \n", "0.008142 | \n", "-0.002592 | \n", "-0.031988 | \n", "-0.046641 | \n", "