Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_class_object(method_name, *args, **kwargs):
    """Return the do-sampler class that corresponds to ``method_name``.

    The module ``<PACKAGE_NAME>.<method_name>`` is imported and the attribute
    whose name is the CamelCase form of ``method_name`` is read from it
    (e.g. ``"weighting_sampler"`` -> ``WeightingSampler``).

    Adapted from
    https://www.bnmetrics.com/blog/factory-pattern-in-python3-simple-version

    :param method_name: snake_case sampler name; used verbatim as the module
        name and, CamelCased, as the class name.
    :param args: accepted but ignored by this function.
    :param kwargs: accepted but ignored by this function.
    :returns: the class object itself (not an instance); always a subclass
        of ``DoSampler``.
    :raises ImportError: if the sampler module cannot be found, the expected
        class is missing from it, or the attribute is not a ``DoSampler``
        subclass. An ``ImportError`` that names a *different* module (for
        example a missing dependency of the sampler module) is re-raised
        unchanged so the real cause is not hidden.
    """
    module_name = method_name
    not_found_message = '{} is not an existing do sampler.'.format(method_name)

    try:
        class_name = string.capwords(method_name, "_").replace('_', '')
        do_sampler_module = import_module('.' + module_name, package=PACKAGE_NAME)
        do_sampler_class = getattr(do_sampler_module, class_name)
    except ImportError as e:
        # Only translate the failure when the sampler module itself is the
        # one that is missing; anything else is a genuine import problem.
        if e.name != PACKAGE_NAME + '.' + module_name:
            raise
        raise ImportError(not_found_message) from e
    except AttributeError as e:
        raise ImportError(not_found_message) from e

    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would let a non-sampler class slip through. The
    # ``isinstance`` guard keeps ``issubclass`` from raising TypeError when
    # the attribute is not a class at all.
    is_sampler = isinstance(do_sampler_class, type) and issubclass(do_sampler_class, DoSampler)
    if not is_sampler:
        raise ImportError(not_found_message)
    return do_sampler_class
from dowhy.do_sampler import DoSampler
from dowhy.utils.propensity_score import propensity_of_treatment_score
class WeightingSampler(DoSampler):
    """``DoSampler`` subclass used for weighting-based do sampling."""

    def __init__(
        self,
        data,
        *args,
        params=None,
        variable_types=None,
        num_cores=1,
        keep_original_treatment=False,
        causal_model=None,
        **kwargs
    ):
        """Initialise the base sampler and log which sampler is in use.

        ``data`` and the named keyword options are forwarded to
        ``DoSampler.__init__``. Extra positional arguments (``*args``) and
        extra keyword arguments (``**kwargs``) are accepted but not
        forwarded.
        """
        base_options = {
            "params": params,
            "variable_types": variable_types,
            "num_cores": num_cores,
            "keep_original_treatment": keep_original_treatment,
            "causal_model": causal_model,
        }
        super().__init__(data, **base_options)

        # ``self.logger`` is presumably set up by the base initialiser above.
        for message in (
            "Using WeightingSampler for do sampling.",
            "Caution: do samplers assume iid data.",
        ):
            self.logger.info(message)
from dowhy.do_sampler import DoSampler
import numpy as np
import pymc3 as pm
import networkx as nx
class McmcSampler(DoSampler):
    """``DoSampler`` subclass used for MCMC-based do sampling."""

    def __init__(
        self,
        data,
        *args,
        params=None,
        variable_types=None,
        num_cores=1,
        keep_original_treatment=False,
        causal_model=None,
        **kwargs
    ):
        """Initialise the base sampler, build the sampler and store the graph.

        ``data`` and the named keyword options are forwarded to
        ``DoSampler.__init__``. Extra positional arguments (``*args``) and
        extra keyword arguments (``**kwargs``) are accepted but not
        forwarded.
        """
        base_options = {
            "params": params,
            "variable_types": variable_types,
            "causal_model": causal_model,
            "num_cores": num_cores,
            "keep_original_treatment": keep_original_treatment,
        }
        super().__init__(data, **base_options)
        self.logger.info("Using McmcSampler for do sampling.")

        self.point_sampler = False
        # The sampler is built before ``self.g`` is assigned; that order is
        # kept exactly as-is because ``_construct_sampler`` is not visible
        # here.
        self.sampler = self._construct_sampler()

        # NOTE(review): ``causal_model`` defaults to None but is dereferenced
        # unconditionally below -- confirm callers always supply it.
        model_graph = causal_model._graph
        self.g = model_graph.get_unconfounded_observed_subgraph()
from dowhy.do_sampler import DoSampler
from dowhy.utils.propensity_score import state_propensity_score
class MultivariateWeightingSampler(DoSampler):
def __init__(
    self,
    data,
    identified_estimand,
    treatments,
    outcomes,
    *args,
    params=None,
    variable_types=None,
    num_cores=1,
    keep_original_treatment=False,
    **kwargs
):
    """Initialise the base sampler and mark this as a non-point sampler.

    ``data``, ``identified_estimand``, ``treatments`` and ``outcomes`` are
    forwarded positionally to the parent initialiser, followed by the named
    keyword options. Extra positional arguments (``*args``) and extra
    keyword arguments (``**kwargs``) are accepted but not forwarded.
    """
    required = (data, identified_estimand, treatments, outcomes)
    base_options = {
        "params": params,
        "variable_types": variable_types,
        "num_cores": num_cores,
        "keep_original_treatment": keep_original_treatment,
    }
    super().__init__(*required, **base_options)

    # ``self.logger`` is presumably set up by the parent initialiser above.
    for message in (
        "Using MultivariateWeightingSampler for do sampling.",
        "Caution: do samplers assume iid data.",
    ):
        self.logger.info(message)

    self.point_sampler = False
def make_treatment_effective(self, x):