Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_columns, to_list=False):
    """Pool the sequence embedding of every variable-length sparse feature column.

    Each column in ``varlen_sparse_feature_columns`` is processed exactly once:

    * If ``fc.length_name`` is set, ``features[fc.length_name]`` is passed
      alongside the sequence to ``SequencePoolingLayer(fc.combiner,
      supports_masking=False)``.
    * Otherwise the sequence alone is passed to
      ``SequencePoolingLayer(fc.combiner, supports_masking=True)``
      (presumably pooling over the input's Keras mask -- confirm against
      SequencePoolingLayer).

    In both branches, when ``fc.weight_name`` is set, the embedding is first
    fed through ``WeightedSequenceLayer`` together with
    ``features[fc.weight_name]``.

    Args:
        embedding_dict: Mapping from ``fc.name`` to that column's sequence
            embedding.
        features: Mapping from input names (``fc.length_name`` and
            ``fc.weight_name``) to their input tensors.
        varlen_sparse_feature_columns: Iterable of column descriptors
            exposing ``name``, ``combiner``, ``length_name``, ``weight_name``
            and ``group_name``.
        to_list: If true, return a flat iterator over all pooled outputs
            instead of the per-group mapping.

    Returns:
        A ``defaultdict(list)`` mapping ``fc.group_name`` to the pooled
        outputs of that group's columns, in input order.  When ``to_list`` is
        true, an ``itertools.chain`` iterator over those outputs, group by
        group.  Note that this is an iterator, not a ``list``: it can be
        consumed only once.
    """
    pooling_vec_list = defaultdict(list)
    for fc in varlen_sparse_feature_columns:
        seq_input = embedding_dict[fc.name]
        if fc.length_name is not None:
            # Explicit sequence lengths are supplied as a separate input.
            length_input = features[fc.length_name]
            if fc.weight_name is not None:
                seq_input = WeightedSequenceLayer()(
                    [seq_input, length_input, features[fc.weight_name]])
            vec = SequencePoolingLayer(fc.combiner, supports_masking=False)(
                [seq_input, length_input])
        else:
            # No length input: the layers are built with supports_masking=True.
            if fc.weight_name is not None:
                seq_input = WeightedSequenceLayer(supports_masking=True)(
                    [seq_input, features[fc.weight_name]])
            vec = SequencePoolingLayer(fc.combiner, supports_masking=True)(
                seq_input)
        pooling_vec_list[fc.group_name].append(vec)
    if to_list:
        return chain.from_iterable(pooling_vec_list.values())
    return pooling_vec_list
def get_config(self):
    """Return this layer's serializable configuration.

    Starts from the parent class's ``get_config()`` and adds
    ``weight_normalization`` and ``supports_masking``. If the parent config
    already contains either key, the value from this layer overrides it.
    """
    own_settings = {
        'weight_normalization': self.weight_normalization,
        'supports_masking': self.supports_masking,
    }
    merged = dict(super(WeightedSequenceLayer, self).get_config())
    merged.update(own_settings)
    return merged