# Source code for rfgb.rdn.learn

# -*- coding: utf-8 -*-

#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program (at the base of this repository). If not,
# see <http://www.gnu.org/licenses/>.

from __future__ import print_function
from __future__ import absolute_import

from ..tree import node
from ..utils import Utils

[docs]def learn(
targets,
numTrees=10,
path="",
regression=False,
softm=False,
alpha=0.0,
beta=0.0,
saveJson=True,
):
"""

Learn a relational dependency network from facts and positive/negative
examples via relational regression trees.

.. note:: This currently requires that training data is stored as files
on disk.

:param targets: List of target predicates to learn models for.
:type targets: list of str.

:param numTrees: Number of trees to learn.
:type numTrees: int.

:param path: Path to the location training data is stored.
:type path: str.

:param regression: Learn a regression model instead of classification.
:type regression: bool.

:default regression: False

:returns: Dictionary where the key is the target and the value is the
set of trees returned for that target.
:rtype: dict.
"""

# Models will be returned as a dictionary, where the name of the predicate
# will be bound to the set of trees learned for it.
models = {}

for target in targets:

target,
path=path,
regression=regression,
softm=softm,
alpha=alpha,
beta=beta,
)

# Initialize an empty list for the trees.
trees = []

# Learn each tree and update the gradients.
for i in range(numTrees):

node.setMaxDepth(2)
node.learnTree(trainData)
trees.append(node.learnedDecisionTree)

# Save the models learned at this step.
if saveJson:

# Collect the parameters used to learn these trees:
params = {
"target": target,
"trees": i + 1,
"regression": regression,