Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
with maximal sampling. Higher values mean higher weight.
gamma: float
Governs the scaling of the weight of the max samples as a function
of the % of papers read. Higher values mean stronger scaling.
"""
super(TripleBalance, self).__init__()
self.a = a
self.alpha = alpha
self.b = b
self.beta = beta
self.c = c
self.gamma = gamma
self.shuffle = shuffle
self.fallback_model = DoubleBalance(a=a, alpha=alpha, b=b, beta=beta,
random_state=random_state)
self._random_state = get_random_state(random_state)
def __init__(self, random_state=None):
    """Initialize the RandomQuery instance.

    Arguments
    ---------
    random_state:
        Value handed to get_random_state; whatever that helper returns
        is kept on the instance as ``_random_state``. Defaults to None.
    """
    super().__init__()
    # Resolve the seed/state once and keep the result for later use.
    rng = get_random_state(random_state)
    self._random_state = rng
def __init__(self, *args, embedding_fp=None, random_state=None, **kwargs):
    """Set up the Embedding-Idf model.

    Arguments
    ---------
    *args, **kwargs:
        Forwarded unchanged to the parent class initializer.
    embedding_fp: str
        Path to embedding. Only stored here; nothing is read from it
        in this method.
    random_state:
        Value handed to get_random_state; the result is kept on the
        instance as ``_random_state``. Defaults to None.
    """
    super().__init__(*args, **kwargs)
    self.embedding_fp = embedding_fp
    # The embedding attribute starts out empty.
    self.embedding = None
    rng = get_random_state(random_state)
    self._random_state = rng
new_key = key[len(strategy_1)+1:]
kwargs_1[new_key] = value
elif key.starts_with(strategy_2):
new_key = key[len(strategy_2)+1:]
kwargs_2[new_key] = value
else:
logging.warn(f"Key {key} is being ignored for the mixed "
"({strategy_1}, {strategy_2}) query strategy.")
self.strategy_1 = strategy_1
self.strategy_2 = strategy_2
self.query_model1 = get_query_model(strategy_1, **kwargs_1)
self.query_model2 = get_query_model(strategy_2, **kwargs_2)
self._random_state = get_random_state(random_state)
if "random_state" in self.query_model1.default_param:
self.query_model1 = get_query_model(strategy_1, **kwargs_1,
random_state=self._random_state
)
if "random_state" in self.query_model2.default_param:
self.query_model2 = get_query_model(strategy_2, **kwargs_2,
random_state=self._random_state
)
self.mix_ratio = mix_ratio
if query_param is not None:
settings.query_param = query_param
if balance_param is not None:
settings.balance_param = balance_param
if feature_param is not None:
settings.feature_param = feature_param
# Check if mode is valid
if mode in AVAILABLE_REVIEW_CLASSES:
logging.info(f"Start review in '{mode}' mode.")
else:
raise ValueError(f"Unknown mode '{mode}'.")
logging.debug(settings)
# Initialize models.
random_state = get_random_state(seed)
train_model = get_model(settings.model, **settings.model_param,
random_state=random_state)
query_model = get_query_model(settings.query_strategy,
**settings.query_param,
random_state=random_state)
balance_model = get_balance_model(settings.balance_strategy,
**settings.balance_param,
random_state=random_state)
feature_model = get_feature_model(settings.feature_extraction,
**settings.feature_param,
random_state=random_state)
# LSTM models need embedding matrices.
if train_model.name.startswith("lstm-"):
texts = as_data.texts
train_model.embedding_matrix = feature_model.get_embedding_matrix(
def __init__(self, a=2.155, alpha=0.94, b=0.789, beta=1.0,
             random_state=None):
    """Initialize the DoubleBalance instance.

    Arguments
    ---------
    a: float
        Stored unchanged as ``self.a``.
    alpha: float
        Stored unchanged as ``self.alpha``.
    b: float
        Stored unchanged as ``self.b``.
    beta: float
        Stored unchanged as ``self.beta``.
    random_state:
        Value handed to get_random_state; the result is kept on the
        instance as ``_random_state``. Defaults to None.
    """
    super().__init__()
    # Keep the four numeric parameters exactly as supplied.
    self.a, self.alpha = a, alpha
    self.b, self.beta = b, beta
    # A SimpleBalance instance is kept as the fallback model.
    fallback = SimpleBalance()
    self.fallback_model = fallback
    self._random_state = get_random_state(random_state)
def __init__(self, ratio=1.0, random_state=None):
    """Set up the undersampling balance strategy.

    Arguments
    ---------
    ratio: double
        Undersampling ratio applied to the zeros. With a ratio of 0.25,
        for example, only a quarter of the zeros would be sampled while
        all of the ones are kept.
    random_state:
        Value handed to get_random_state; the result is kept on the
        instance as ``_random_state``. Defaults to None.
    """
    super().__init__()
    self.ratio = ratio
    rng = get_random_state(random_state)
    self._random_state = rng
Arguments
---------
cluster_size: int
Size of the clusters to be made. If the size of the clusters is
smaller than the size of the pool, fall back to max sampling.
update_interval: int
Update the clustering every x instances.
random_state: int, RandomState
State/seed of the RNG.
"""
super(ClusterQuery, self).__init__()
self.cluster_size = cluster_size
self.update_interval = update_interval
self.last_update = None
self.fallback_model = MaxQuery()
self._random_state = get_random_state(random_state)