How to use the causalml.inference.tree.models.DecisionTree class in causalml

To help you get started, we’ve selected a few causalml examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: GitHub, uber/causalml — causalml/inference/tree/models.py (View on GitHub)
min_samples_treatment: int, optional (default=10)
            The minimum number of samples required of the experiment group to be split at a leaf node.
        n_reg: int, optional (default=10)
            The regularization parameter defined in Rzepakowski et al. 2012,
            the weight (in terms of sample size) of the parent node influence
            on the child node, only effective for 'KL', 'ED', 'Chi', 'CTS' methods.
        parentNodeSummary : dictionary, optional (default = None)
            Node summary statistics of the parent tree node.

        Returns
        -------
        object of DecisionTree class
        '''

        if len(X) == 0:
            return DecisionTree()

        # Current Node Info and Summary
        currentNodeSummary = self.tree_node_summary(treatment, y,
                                                    min_samples_treatment=min_samples_treatment,
                                                    n_reg=n_reg,
                                                    parentNodeSummary=parentNodeSummary)
        if evaluationFunction == self.evaluate_CTS:
            currentScore = evaluationFunction(currentNodeSummary)
        else:
            currentScore = evaluationFunction(currentNodeSummary, control_name=self.control_name)

        # Prune Stats
        maxAbsDiff = 0
        maxDiff = -1.
        bestTreatment = self.control_name
        suboptTreatment = self.control_name
Source: GitHub, uber/causalml — causalml/inference/tree/models.py (View on GitHub)
*best_set_right, evaluationFunction, max_depth, min_samples_leaf,
                depth + 1, min_samples_treatment=min_samples_treatment,
                n_reg=n_reg, parentNodeSummary=currentNodeSummary
            )

            return DecisionTree(
                col=bestAttribute[0], value=bestAttribute[1],
                trueBranch=trueBranch, falseBranch=falseBranch, summary=dcY,
                maxDiffTreatment=maxDiffTreatment, maxDiffSign=maxDiffSign,
                nodeSummary=currentNodeSummary,
                backupResults=self.uplift_classification_results(treatment, y),
                bestTreatment=bestTreatment, upliftScore=upliftScore
            )
        else:
            if evaluationFunction == self.evaluate_CTS:
                return DecisionTree(
                    results=self.uplift_classification_results(treatment, y),
                    summary=dcY, nodeSummary=currentNodeSummary,
                    bestTreatment=bestTreatment, upliftScore=upliftScore
                )
            else:
                return DecisionTree(
                    results=self.uplift_classification_results(treatment, y),
                    summary=dcY, maxDiffTreatment=maxDiffTreatment,
                    maxDiffSign=maxDiffSign, nodeSummary=currentNodeSummary,
                    bestTreatment=bestTreatment, upliftScore=upliftScore
                )
Source: GitHub, uber/causalml — causalml/inference/tree/models.py (View on GitHub)
dcY['upliftScore'] = [round(upliftScore[0], 4), round(upliftScore[1], 4)]
        dcY['matchScore'] = round(upliftScore[0], 4)

        if bestGain > 0 and depth < max_depth:
            trueBranch = self.growDecisionTreeFrom(
                *best_set_left, evaluationFunction, max_depth, min_samples_leaf,
                depth + 1, min_samples_treatment=min_samples_treatment,
                n_reg=n_reg, parentNodeSummary=currentNodeSummary
            )
            falseBranch = self.growDecisionTreeFrom(
                *best_set_right, evaluationFunction, max_depth, min_samples_leaf,
                depth + 1, min_samples_treatment=min_samples_treatment,
                n_reg=n_reg, parentNodeSummary=currentNodeSummary
            )

            return DecisionTree(
                col=bestAttribute[0], value=bestAttribute[1],
                trueBranch=trueBranch, falseBranch=falseBranch, summary=dcY,
                maxDiffTreatment=maxDiffTreatment, maxDiffSign=maxDiffSign,
                nodeSummary=currentNodeSummary,
                backupResults=self.uplift_classification_results(treatment, y),
                bestTreatment=bestTreatment, upliftScore=upliftScore
            )
        else:
            if evaluationFunction == self.evaluate_CTS:
                return DecisionTree(
                    results=self.uplift_classification_results(treatment, y),
                    summary=dcY, nodeSummary=currentNodeSummary,
                    bestTreatment=bestTreatment, upliftScore=upliftScore
                )
            else:
                return DecisionTree(
Source: GitHub, uber/causalml — causalml/inference/tree/models.py (View on GitHub)
col=bestAttribute[0], value=bestAttribute[1],
                trueBranch=trueBranch, falseBranch=falseBranch, summary=dcY,
                maxDiffTreatment=maxDiffTreatment, maxDiffSign=maxDiffSign,
                nodeSummary=currentNodeSummary,
                backupResults=self.uplift_classification_results(treatment, y),
                bestTreatment=bestTreatment, upliftScore=upliftScore
            )
        else:
            if evaluationFunction == self.evaluate_CTS:
                return DecisionTree(
                    results=self.uplift_classification_results(treatment, y),
                    summary=dcY, nodeSummary=currentNodeSummary,
                    bestTreatment=bestTreatment, upliftScore=upliftScore
                )
            else:
                return DecisionTree(
                    results=self.uplift_classification_results(treatment, y),
                    summary=dcY, maxDiffTreatment=maxDiffTreatment,
                    maxDiffSign=maxDiffSign, nodeSummary=currentNodeSummary,
                    bestTreatment=bestTreatment, upliftScore=upliftScore
                )