{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Learn CPTs of Bayesian Network\n", "\n", "For the design of Bayesian Networks for a given problem domain, one can follow one of the following approaches:\n", "\n", "1. topology (nodes and arcs) and the conditional probabilities are configured by applying expert knowledge, i.e. experts determine the relevant variables and their dependencies and estimate the conditional probabilities\n", "\n", "2. topology is determined by experts, but the conditional probabilities of the CPTs are learned from data\n", "\n", "3. topology and conditional probabilities are learned from data\n", "\n", "\n", "This section shows how [pyAgrum](https://pyagrum.readthedocs.io/en/0.16.3/index.html) can be applied to the second option, i.e. to learn the Conditional Probability Tables (CPTs) of Bayesian Networks whose topology is known.\n", "\n", "As demonstrated for example in [learn Bayesian Network from covid-19 data](https://github.com/AlvaroCorrales/BayesianNetworks/blob/main/Bayesian_Networks_Tutorial_covid-19.ipynb), pyAgrum can also be applied to option 3, where the topology and the CPTs are learned from data.\n", "\n", "In contrast to purely data-driven Machine Learning algorithms, **Bayesian Networks provide the important capability to integrate knowledge from data with expert knowledge.**\n", "\n", "In order to demonstrate the learning capability we apply two Bayesian Networks. Both have the identical topology, but they are initialized with different random CPTs. Since a Bayesian Network represents a joint probability distribution (JPT), data can be generated from this network by sampling according to the network's JPT. We sample 10000 instances from one network and apply this data to learn the CPTs of the other.
By comparing the distance between the JPTs before and after the second network has been trained with the data from the first network, one can verify that the second network becomes similar to the first by learning from the first network's data.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:48.518018Z", "start_time": "2019-02-25T16:09:48.332299Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:41.092604Z", "iopub.status.busy": "2021-04-02T20:30:41.092604Z", "iopub.status.idle": "2021-04-02T20:30:41.436603Z", "shell.execute_reply": "2021-04-02T20:30:41.435602Z", "shell.execute_reply.started": "2021-04-02T20:30:41.092604Z" } }, "outputs": [], "source": [ "%matplotlib inline\n", "from pylab import *\n", "import matplotlib.pyplot as plt\n", "import os\n", "import pyAgrum as gum\n", "import pyAgrum.lib.notebook as gnb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading two BNs\n", "\n", "Two identical Bayes Nets for the *Visit to Asia?*-problem (see section [Bayesian Networks with pyAgrum](BayesNetAsia.ipynb)) are loaded from disk." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.158026Z", "start_time": "2019-02-25T16:09:48.529386Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:41.516604Z", "iopub.status.busy": "2021-04-02T20:30:41.516604Z", "iopub.status.idle": "2021-04-02T20:30:41.715602Z", "shell.execute_reply": "2021-04-02T20:30:41.715602Z", "shell.execute_reply.started": "2021-04-02T20:30:41.516604Z" } }, "outputs": [ { "data": { "text/html": [ "
[Graph output: two side-by-side diagrams of the same network (nodes A, T, S, L, B, E, X, D; arcs A->T, T->E, L->E, S->L, S->B, B->D, E->D, E->X), captioned First bn and Second bn]
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "bn=gum.loadBN(os.path.join(\"out\",\"VisitAsia.bif\"))\n", "bn2=gum.loadBN(os.path.join(\"out\",\"VisitAsia.bif\"))\n", "\n", "gnb.sideBySide(bn,bn2,\n", " captions=['First bn','Second bn'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As shown below, both BNs have the same CPTs" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
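<table><tr><td><table><tr><th>E</th><th>B</th><th>D=0</th><th>D=1</th></tr><tr><th>0</th><th>0</th><td>0.9000</td><td>0.1000</td></tr><tr><th>0</th><th>1</th><td>0.2000</td><td>0.8000</td></tr><tr><th>1</th><th>0</th><td>0.3000</td><td>0.7000</td></tr><tr><th>1</th><th>1</th><td>0.1000</td><td>0.9000</td></tr></table></td><td><table><tr><th>E</th><th>B</th><th>D=0</th><th>D=1</th></tr><tr><th>0</th><th>0</th><td>0.9000</td><td>0.1000</td></tr><tr><th>0</th><th>1</th><td>0.2000</td><td>0.8000</td></tr><tr><th>1</th><th>0</th><td>0.3000</td><td>0.7000</td></tr><tr><th>1</th><th>1</th><td>0.1000</td><td>0.9000</td></tr></table></td></tr><tr><td>CPT Dyspnoae bn</td><td>CPT Dyspnoae bn2</td></tr></table>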
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "gnb.sideBySide(bn.cpt(\"D\"),bn2.cpt(\"D\"),\n", " captions=['CPT Dyspnoae bn','CPT Dyspnoae bn2'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Randomizing the parameters\n", "\n", "Next, for both BNs the values of the CPTs are randomized" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.164096Z", "start_time": "2019-02-25T16:09:53.161235Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:41.717603Z", "iopub.status.busy": "2021-04-02T20:30:41.716603Z", "iopub.status.idle": "2021-04-02T20:30:41.732606Z", "shell.execute_reply": "2021-04-02T20:30:41.731605Z", "shell.execute_reply.started": "2021-04-02T20:30:41.717603Z" } }, "outputs": [], "source": [ "bn.generateCPTs()\n", "bn2.generateCPTs()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As can be seen below, the CPTs of both BNs now have new and different values:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
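<table><tr><td><table><tr><th>E</th><th>B</th><th>D=0</th><th>D=1</th></tr><tr><th>0</th><th>0</th><td>0.2638</td><td>0.7362</td></tr><tr><th>0</th><th>1</th><td>0.4127</td><td>0.5873</td></tr><tr><th>1</th><th>0</th><td>0.0775</td><td>0.9225</td></tr><tr><th>1</th><th>1</th><td>0.5353</td><td>0.4647</td></tr></table></td><td><table><tr><th>E</th><th>B</th><th>D=0</th><th>D=1</th></tr><tr><th>0</th><th>0</th><td>0.5436</td><td>0.4564</td></tr><tr><th>0</th><th>1</th><td>0.3282</td><td>0.6718</td></tr><tr><th>1</th><th>0</th><td>0.8612</td><td>0.1388</td></tr><tr><th>1</th><th>1</th><td>0.0176</td><td>0.9824</td></tr></table></td></tr><tr><td>CPT Dyspnoae bn</td><td>CPT Dyspnoae bn2</td></tr></table>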
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "gnb.sideBySide(bn.cpt(\"D\"),bn2.cpt(\"D\"),\n", " captions=['CPT Dyspnoae bn','CPT Dyspnoae bn2'])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.180035Z", "start_time": "2019-02-25T16:09:53.165848Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:41.733605Z", "iopub.status.busy": "2021-04-02T20:30:41.733605Z", "iopub.status.idle": "2021-04-02T20:30:41.748603Z", "shell.execute_reply": "2021-04-02T20:30:41.747603Z", "shell.execute_reply.started": "2021-04-02T20:30:41.733605Z" } }, "outputs": [ { "data": { "text/html": [ "
<table><tr><td><table><tr><th>S</th><th>L=0</th><th>L=1</th></tr><tr><th>0</th><td>0.3396</td><td>0.6604</td></tr><tr><th>1</th><td>0.4345</td><td>0.5655</td></tr></table></td><td><table><tr><th>S</th><th>L=0</th><th>L=1</th></tr><tr><th>0</th><td>0.3496</td><td>0.6504</td></tr><tr><th>1</th><td>0.9099</td><td>0.0901</td></tr></table></td></tr><tr><td>CPT Lung Cancer bn</td><td>CPT Lung Cancer bn2</td></tr></table>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "gnb.sideBySide(bn.cpt(\"L\"),bn2.cpt(\"L\"),\n", " captions=['CPT Lung Cancer bn','CPT Lung Cancer bn2'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exact KL-divergence " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A Bayesian Network represents a joint probability distribution (JPT). For measuring the similarity of two probability distributions, different measures exist. Here, we apply the [Kullback-Leibler (KL) Divergence](https://en.wikipedia.org/wiki/Kullback–Leibler_divergence). The lower the KL-divergence, the more similar the two distributions are." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.188124Z", "start_time": "2019-02-25T16:09:53.182178Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:41.749605Z", "iopub.status.busy": "2021-04-02T20:30:41.749605Z", "iopub.status.idle": "2021-04-02T20:30:41.763606Z", "shell.execute_reply": "2021-04-02T20:30:41.762606Z", "shell.execute_reply.started": "2021-04-02T20:30:41.749605Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.8515800881220987\n" ] } ], "source": [ "g1=gum.ExactBNdistance(bn,bn2)\n", "before_learning=g1.compute()\n", "print(before_learning['klPQ'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just to be sure that the distance between a BN and itself is 0:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.194564Z", "start_time": "2019-02-25T16:09:53.190123Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:41.765607Z", "iopub.status.busy": "2021-04-02T20:30:41.764605Z", "iopub.status.idle": "2021-04-02T20:30:41.778605Z", "shell.execute_reply": "2021-04-02T20:30:41.777605Z", "shell.execute_reply.started": "2021-04-02T20:30:41.765607Z" } }, "outputs": [ { "name":
"stdout", "output_type": "stream", "text": [ "0.0\n" ] } ], "source": [ "g0=gum.ExactBNdistance(bn,bn)\n", "print(g0.compute()['klPQ'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As shown below, the `compute()` method of class `ExactBNdistance` provides not only the Kullback-Leibler Divergence but also other distance measures. However, here we only use KL." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'klPQ': 2.8515800881220987,\n", " 'errorPQ': 0,\n", " 'klQP': 2.1525276737980534,\n", " 'errorQP': 0,\n", " 'hellinger': 0.8325838211096981,\n", " 'bhattacharya': 0.4255625805347223,\n", " 'jensen-shannon': 0.43254545996219734}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "before_learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data from the original BN\n", "\n", "By applying the function `generateCSV()` one can sample data from a Bayesian Network. In the code cell below, 10000 samples, each describing the values of the 8 random variables for one fictional patient, are generated and saved in `out/test.csv`."
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.318968Z", "start_time": "2019-02-25T16:09:53.196422Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:41.779604Z", "iopub.status.busy": "2021-04-02T20:30:41.779604Z", "iopub.status.idle": "2021-04-02T20:30:41.950605Z", "shell.execute_reply": "2021-04-02T20:30:41.950605Z", "shell.execute_reply.started": "2021-04-02T20:30:41.779604Z" }, "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "-72651.2778772939" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gum.generateCSV(bn,os.path.join(\"out\",\"test.csv\"),10000,False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learn CPTs of Bayesian Network from Data\n", "\n", "Next, we use the data sampled from Bayesian Network `bn` above to learn the CPTs of other Bayesian Networks (`bn3`, `bn4` and `bn2`) with the same topology as `bn`. We expect that after a successful learning process the KL-Divergence between `bn` and each learned network is low, i.e. the nets are similar.\n", "\n", "There exist different options to learn the CPTs of a Bayesian Network. Below, we implement the following 3 options:\n", "\n", "1. the `BNLearner` class from pyAgrum\n", "2. the `BNLearner` class from pyAgrum with Laplace Smoothing\n", "3.
the [pandas crosstab() function](https://pandas.pydata.org/pandas-docs/version/0.23/generated/pandas.crosstab.html) for calculating CPTs\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pyAgrum Learners\n", "\n", "**BNLearner() without smoothing:**" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2021-04-02T20:30:41.952606Z", "iopub.status.busy": "2021-04-02T20:30:41.951606Z", "iopub.status.idle": "2021-04-02T20:30:42.012604Z", "shell.execute_reply": "2021-04-02T20:30:42.011604Z", "shell.execute_reply.started": "2021-04-02T20:30:41.952606Z" } }, "outputs": [], "source": [ "learner=gum.BNLearner(os.path.join(\"out\",\"test.csv\"),bn) \n", "bn3=learner.learnParameters(bn.dag())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**BNLearner() with Laplace Smoothing:**" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2021-04-02T20:30:41.952606Z", "iopub.status.busy": "2021-04-02T20:30:41.951606Z", "iopub.status.idle": "2021-04-02T20:30:42.012604Z", "shell.execute_reply": "2021-04-02T20:30:42.011604Z", "shell.execute_reply.started": "2021-04-02T20:30:41.952606Z" } }, "outputs": [], "source": [ "learner=gum.BNLearner(os.path.join(\"out\",\"test.csv\"),bn) \n", "learner.useAprioriSmoothing(100) # a count C is replaced by C+100\n", "bn4=learner.learnParameters(bn.dag())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As shown below, both approaches learn Bayesian Networks that have a small KL-divergence to the Bayesian Network from which the training data has been sampled:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2021-04-02T20:30:41.952606Z", "iopub.status.busy": "2021-04-02T20:30:41.951606Z", "iopub.status.idle": "2021-04-02T20:30:42.012604Z", "shell.execute_reply": "2021-04-02T20:30:42.011604Z", "shell.execute_reply.started":
"2021-04-02T20:30:41.952606Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KL-Divergence for option without smoothing :0.0010387354935079448\n", "KL-Divergence for option with smoothing(100):0.003478427487496688\n" ] } ], "source": [ "after_pyAgrum_learning=gum.ExactBNdistance(bn,bn3).compute()\n", "after_pyAgrum_learning_with_laplace=gum.ExactBNdistance(bn,bn4).compute()\n", "print(\"KL-Divergence for option without smoothing :{}\".format(after_pyAgrum_learning['klPQ']))\n", "print(\"KL-Divergence for option with smoothing(100):{}\".format(after_pyAgrum_learning_with_laplace['klPQ']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pandas to learn CPTs" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.651559Z", "start_time": "2019-02-25T16:09:53.321004Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:42.013605Z", "iopub.status.busy": "2021-04-02T20:30:42.013605Z", "iopub.status.idle": "2021-04-02T20:30:42.327603Z", "shell.execute_reply": "2021-04-02T20:30:42.325603Z", "shell.execute_reply.started": "2021-04-02T20:30:42.013605Z" } }, "outputs": [ { "data": { "text/html": [ "
<table><tr><th></th><th>X</th><th>B</th><th>S</th><th>T</th><th>E</th><th>A</th><th>D</th><th>L</th></tr><tr><th>0</th><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td><td>1</td><td>0</td><td>1</td></tr><tr><th>1</th><td>0</td><td>1</td><td>1</td><td>0</td><td>0</td><td>1</td><td>1</td><td>0</td></tr><tr><th>2</th><td>0</td><td>1</td><td>0</td><td>1</td><td>0</td><td>0</td><td>1</td><td>0</td></tr><tr><th>3</th><td>1</td><td>0</td><td>1</td><td>0</td><td>1</td><td>1</td><td>1</td><td>0</td></tr><tr><th>4</th><td>0</td><td>1</td><td>0</td><td>1</td><td>0</td><td>1</td><td>1</td><td>0</td></tr></table>
" ], "text/plain": [ " X B S T E A D L\n", "0 0 0 1 0 0 1 0 1\n", "1 0 1 1 0 0 1 1 0\n", "2 0 1 0 1 0 0 1 0\n", "3 1 0 1 0 1 1 1 0\n", "4 0 1 0 1 0 1 1 0" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas\n", "df=pandas.read_csv(os.path.join(\"out\",\"test.csv\"))\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the `crosstab` function of pandas to determine conditional counts, here the counts of `D` for each combination of its parents `E` and `B`:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.677608Z", "start_time": "2019-02-25T16:09:53.653478Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:42.328604Z", "iopub.status.busy": "2021-04-02T20:30:42.328604Z", "iopub.status.idle": "2021-04-02T20:30:42.372604Z", "shell.execute_reply": "2021-04-02T20:30:42.371604Z", "shell.execute_reply.started": "2021-04-02T20:30:42.328604Z" } }, "outputs": [ { "data": { "text/html": [ "
<table><tr><th>E</th><th colspan=2>0</th><th colspan=2>1</th></tr><tr><th>B</th><th>0</th><th>1</th><th>0</th><th>1</th></tr><tr><th>D=0</th><td>550</td><td>915</td><td>200</td><td>1523</td></tr><tr><th>D=1</th><td>1541</td><td>1264</td><td>2635</td><td>1372</td></tr></table>
" ], "text/plain": [ "E 0 1 \n", "B 0 1 0 1\n", "D \n", "0 550 915 200 1523\n", "1 1541 1264 2635 1372" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d_counts=pandas.crosstab(df['D'],[df['E'],df['B']])\n", "d_counts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `normalize='columns'` the same function yields conditional probabilities, because each column (one configuration of the parents) is normalized to sum to 1:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table><tr><th>E</th><th colspan=2>0</th><th colspan=2>1</th></tr><tr><th>B</th><th>0</th><th>1</th><th>0</th><th>1</th></tr><tr><th>D=0</th><td>0.263032</td><td>0.419917</td><td>0.070547</td><td>0.526079</td></tr><tr><th>D=1</th><td>0.736968</td><td>0.580083</td><td>0.929453</td><td>0.473921</td></tr></table>
" ], "text/plain": [ "E 0 1 \n", "B 0 1 0 1\n", "D \n", "0 0.263032 0.419917 0.070547 0.526079\n", "1 0.736968 0.580083 0.929453 0.473921" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d_condprob=pandas.crosstab(df['D'],[df['E'],df['B']],normalize=\"columns\")\n", "d_condprob" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A general method for estimating Bayesian network parameters from a CSV file using pandas" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.703200Z", "start_time": "2019-02-25T16:09:53.696660Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:42.389604Z", "iopub.status.busy": "2021-04-02T20:30:42.389604Z", "iopub.status.idle": "2021-04-02T20:30:42.403602Z", "shell.execute_reply": "2021-04-02T20:30:42.403602Z", "shell.execute_reply.started": "2021-04-02T20:30:42.389604Z" } }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def computeCPTfromDF(bn,df,name):\n", "    \"\"\"\n", "    Compute the CPT of variable \"name\" in the BN bn from the database df\n", "    \"\"\"\n", "    node_id=bn.idFromName(name)\n", "    # the last entry of var_names is the variable itself, the others are its parents\n", "    var_names=list(bn.cpt(node_id).var_names)\n", "    domains=[bn.variableFromName(v).domainSize() for v in var_names]\n", "    parents=var_names[:-1]\n", "\n", "    if parents:\n", "        # one column per parent configuration, normalized to sum to 1\n", "        s=pandas.crosstab(df[name],[df[parent] for parent in parents],normalize=\"columns\")\n", "    else:\n", "        # value_counts() sorts by frequency; sort_index() restores the label order\n", "        s=df[name].value_counts(normalize=True).sort_index()\n", "\n", "    bn.cpt(node_id)[:]=np.array(s.transpose()).reshape(*domains)\n", "\n", "def ParametersLearning(bn,df):\n", "    \"\"\"\n", "    Compute the CPTs of every variable in the BN bn from the database df\n", "    \"\"\"\n", "    for name in bn.names():\n", "        computeCPTfromDF(bn,df,name)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.831116Z", "start_time": "2019-02-25T16:09:53.706768Z" }, "execution": {
"iopub.execute_input": "2021-04-02T20:30:42.405605Z", "iopub.status.busy": "2021-04-02T20:30:42.404603Z", "iopub.status.idle": "2021-04-02T20:30:42.468603Z", "shell.execute_reply": "2021-04-02T20:30:42.467604Z", "shell.execute_reply.started": "2021-04-02T20:30:42.405605Z" } }, "outputs": [], "source": [ "ParametersLearning(bn2,df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The KL-divergence between `bn` and `bn2` should now be much lower than before learning:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.837243Z", "start_time": "2019-02-25T16:09:53.833115Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:42.469604Z", "iopub.status.busy": "2021-04-02T20:30:42.469604Z", "iopub.status.idle": "2021-04-02T20:30:42.483604Z", "shell.execute_reply": "2021-04-02T20:30:42.482604Z", "shell.execute_reply.started": "2021-04-02T20:30:42.469604Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BEFORE LEARNING\n", "2.8515800881220987\n", "AFTER LEARNING\n", "0.6419513483372401\n" ] } ], "source": [ "g1=gum.ExactBNdistance(bn,bn2)\n", "print(\"BEFORE LEARNING\")\n", "print(before_learning['klPQ'])\n", "print(\"AFTER LEARNING\")\n", "print(g1.compute()['klPQ'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The learned CPTs should also be close to the original ones:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:09:53.847007Z", "start_time": "2019-02-25T16:09:53.838694Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:42.486604Z", "iopub.status.busy": "2021-04-02T20:30:42.486604Z", "iopub.status.idle": "2021-04-02T20:30:42.498604Z", "shell.execute_reply": "2021-04-02T20:30:42.497606Z", "shell.execute_reply.started": "2021-04-02T20:30:42.486604Z" }, "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
<table><tr><td><table><tr><th>S</th><th>L=0</th><th>L=1</th></tr><tr><th>0</th><td>0.3396</td><td>0.6604</td></tr><tr><th>1</th><td>0.4345</td><td>0.5655</td></tr></table></td><td><table><tr><th>S</th><th>L=0</th><th>L=1</th></tr><tr><th>0</th><td>0.3487</td><td>0.6513</td></tr><tr><th>1</th><td>0.4353</td><td>0.5647</td></tr></table></td></tr><tr><td>Original BN</td><td>learned BN</td></tr></table>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "gnb.sideBySide(bn.cpt(\"L\"),\n", " bn2.cpt(\"L\"),\n", " captions=[\"Original BN\",\"learned BN\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Influence of the size of the database on the quality of learned parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What is the effect of increasing the size of the database on the KL-divergence? We expect the KL-divergence to decrease towards 0 as more samples are used." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "ExecuteTime": { "end_time": "2019-02-25T16:10:12.870326Z", "start_time": "2019-02-25T16:09:53.849931Z" }, "execution": { "iopub.execute_input": "2021-04-02T20:30:42.501605Z", "iopub.status.busy": "2021-04-02T20:30:42.501605Z", "iopub.status.idle": "2021-04-02T20:30:52.112604Z", "shell.execute_reply": "2021-04-02T20:30:52.111603Z", "shell.execute_reply.started": "2021-04-02T20:30:42.501605Z" } }, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'klPQ(bn,learnedBN(x))')" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "[Plot: klPQ(bn,learnedBN(x)); x-axis: size of the database; y-axis: KL]" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "res=[]\n", "for i in range(200,10001,50):\n", " ParametersLearning(bn2,df[:i])\n", " g1=gum.ExactBNdistance(bn,bn2)\n", " res.append(g1.compute()['klPQ'])\n", "fig=figure(figsize=(8,5))\n", "plt.plot(range(200,10001,50),res)\n", "plt.xlabel(\"size of the database\")\n", "plt.ylabel(\"KL\")\n", "plt.title(\"klPQ(bn,learnedBN(x))\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }