# -*- coding: utf-8 -*-
# Copyright © 2017-2019 rfgb Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program (at the base of this repository). If not,
# see <http://www.gnu.org/licenses/>
"""
Data structures and methods for learning decision trees.
"""
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from .utils import Utils
from .logic import Logic
from .logic import Prover
from copy import deepcopy
class node:
    """
    A node in a learned decision tree.

    Tree-wide learning state lives in class attributes shared by every
    instance. ``initTree`` resets ``expandQueue``, ``learnedDecisionTree``
    and ``data`` before each tree; ``depth`` and ``maxDepth`` are not reset.

    :cvar expandQueue: Nodes waiting to be expanded. New nodes are inserted
        at the front and popped from the back (breadth-first order).
    :cvar depth: Deepest level at which a child node has been created.
    :cvar maxDepth: Level at which a node is turned into a leaf (default 1).
    :cvar learnedDecisionTree: Clauses (strings) learned for the current tree.
    :cvar data: Training data most recently passed to ``initTree``.
    """

    expandQueue = []
    depth = 0
    maxDepth = 1
    learnedDecisionTree = []
    data = None

    def __init__(
        self,
        test=None,
        examples=None,
        information=None,
        level=None,
        parent=None,
        pos=None,
    ):
        """
        Create a node and register it in the expansion queue.

        :param test: Test condition in the form of a horn clause.
        :param examples: Examples available for testing at this node.
        :param information: Information contained at this node.
        :param level: Level of this node in the tree (0 for root). ``None``
            is treated like the root level.
        :param parent: Pointer to the parent. Ignored (and replaced by the
            string ``"root"``) unless ``level`` is greater than 0.
        :param pos: Position in the tree ('left' or 'right').
        """
        self.test = test
        # ``None > 0`` raises TypeError on Python 3, so the default level is
        # checked explicitly and handled like the root.
        below_root = level is not None and level > 0
        self.parent = parent if below_root else "root"
        self.pos = pos
        self.examples = examples
        self.information = information
        self.level = level
        self.left = None
        self.right = None
        # Add to the queue of nodes to expand.
        node.expandQueue.insert(0, self)

    @staticmethod
    def setMaxDepth(depth):
        """
        Set the maximum depth of the tree.

        :param depth: Level at which nodes stop being expanded.
        """
        node.maxDepth = depth

    @staticmethod
    def initTree(trainingData):
        """
        Reset the shared learning state and create the root node.

        :param trainingData: Object providing ``regression``, ``variance``
            and either ``examples`` (regression) or ``pos``/``neg``.
        """
        node.data = trainingData
        # Reset node queue and learned clauses for every tree to be learned.
        node.expandQueue = []
        node.learnedDecisionTree = []

        if trainingData.regression:
            # Regression examples have no pos/neg split. Materialise the keys
            # so both branches hand the root a real list (on Python 3
            # ``dict.keys()`` is only a view).
            examples = list(trainingData.examples.keys())
        else:
            # All other models use the positive and negative examples.
            examples = list(trainingData.pos.keys()) + list(trainingData.neg.keys())

        # The constructor enqueues the root; no reference needs to be kept.
        node(
            test=None,
            examples=examples,
            information=trainingData.variance(examples),
            level=0,
            parent="root",
        )

    @staticmethod
    def learnTree(data):
        """
        Learn a decision tree from ``data``.

        The learned clauses end up in ``node.learnedDecisionTree``.

        :param data: Training data, forwarded to ``initTree`` and to every
            node expansion.
        """
        # Create the root.
        node.initTree(data)

        # Expanding a node may enqueue its children, so loop until drained.
        while node.expandQueue:
            node.expandQueue.pop().expandOnBestTest(data)

        # Order clauses by the length of the text before the first space,
        # longest first. Sort-then-reverse is kept on purpose: unlike
        # ``reverse=True`` it also reverses the relative order of ties.
        node.learnedDecisionTree.sort(key=lambda x: len(x.split(" ")[0]))
        node.learnedDecisionTree.reverse()

    def getTrueExamples(self, clause, test, data):
        """
        Return the examples at this node that satisfy ``clause`` with the
        ``test`` literal conjoined.

        :param clause: Clause string ending in ``-`` (empty body) or ``;``.
        :param test: Literal to append to the clause body.
        :param data: Training data handed to ``Prover.prove``.
        :returns: List of satisfied examples, in the order of ``self.examples``.
        """
        # Construct the clause for the prover. Strings are immutable, so no
        # defensive copy of ``clause`` is needed.
        query = clause
        if query[-1] == "-":
            query += test
        elif query[-1] == ";":
            query = query.replace(";", ",") + test

        return [
            example
            for example in self.examples
            if Prover.prove(data, example, query)
        ]

    def _addChild(self, examples, pos, data):
        """Create a child one level below this node and update ``node.depth``."""
        child = node(
            test=None,
            examples=examples,
            information=data.variance(examples),
            level=self.level + 1,
            parent=self,
            pos=pos,
        )
        node.depth = max(node.depth, self.level + 1)
        return child

    def expandOnBestTest(self, data=None):
        """
        Expand the node based on the best test.

        Either records a leaf clause in ``node.learnedDecisionTree`` or picks
        the test with the lowest weighted variance, stores it in
        ``self.test`` and creates the left/right children.

        :param data: Training data. Falls back to ``node.data`` when omitted.
        """
        if data is None:
            data = node.data

        target = data.getTarget()

        # Initialize clause learned at this node with empty body, then walk
        # up to the root collecting the tests on the path. Only tests taken
        # on a 'left' edge become part of the clause body.
        clause = target + ":-"
        ancestorTests = []
        current = self
        while current.parent != "root":
            if current.pos == "left":
                clause += current.parent.test + ";"
                ancestorTests.append(current.parent.test)
            elif current.pos == "right":
                ancestorTests.append(current.parent.test)
            current = current.parent

        # Leaf: maximum depth reached or nothing left to explain.
        if self.level == node.maxDepth or round(self.information, 3) == 0:
            # Drop the trailing ';' unless the body is empty.
            head = clause if clause[-1] == "-" else clause[:-1]
            node.learnedDecisionTree.append(
                head + " " + str(Utils.getleafValue(self.examples))
            )
            return

        # NOTE(review): with clauses built above this only fires when the
        # last collected test itself ends in '-'; kept as in the original.
        if clause[-2] == "-":
            clause = clause[:-1]

        # For every literal contained in the facts generate test conditions
        # (all literal, variable, and constant combinations).
        tests = []
        for literal in data.getLiterals():
            literalName = literal[0]
            literalTypeSpecification = literal[1]
            tests += Logic.generateTests(literalName, literalTypeSpecification, clause)

        # Deduplicate, drop tests already used by an ancestor, and sort so
        # that ties in score are broken the same way on every run (plain set
        # iteration order varies with string hash randomisation).
        usedTests = set(ancestorTests)
        candidates = sorted(test for test in set(tests) if test not in usedTests)

        # Track the lowest weighted variance seen so far.
        minScore = float("inf")
        bestTest = ""
        # Examples which satisfy / do not satisfy the clause with the best test.
        bestTExamples, bestFExamples = [], []
        # Total number of examples.
        example_len = len(self.examples)

        for test in candidates:
            # Examples which are satisfied.
            tExamples = self.getTrueExamples(clause, test, data)
            # Examples which are not satisfied (under closed world
            # assumption); the set keeps this pass linear.
            satisfied = set(tExamples)
            fExamples = [
                example for example in self.examples if example not in satisfied
            ]
            # Weighted variance of the split.
            score = (len(tExamples) / example_len) * data.variance(tExamples) + (
                len(fExamples) / example_len
            ) * data.variance(fExamples)
            if score < minScore:
                minScore = score
                bestTest = test
                bestTExamples = tExamples
                bestFExamples = fExamples

        Utils.addVariableTypes(bestTest)  # add variable types of new variables
        self.test = bestTest
        print("Best test found at the current node: ", self.test)

        # If True examples need further explaining,
        # create left node and add to the queue.
        if bestTExamples:
            self.left = self._addChild(bestTExamples, "left", data)

        # If False examples need further explaining,
        # create right node and add to the queue.
        if bestFExamples:
            self.right = self._addChild(bestFExamples, "right", data)

        # If no test was found, append the clause as it is.
        if self.test == "" or round(self.information, 3) == 0:
            node.learnedDecisionTree.append(
                clause if clause[-1] == "-" else clause[:-1]
            )
        return