Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_embedding_vec_list(embedding_dict, input_dict, sparse_fg_list,
                           return_feat_list=(), mask_feat_list=()):
    """Apply each sparse feature group's embedding callable to its input.

    Args:
        embedding_dict: Mapping from feature name to a callable (presumably
            an embedding layer) that is applied to the feature's lookup
            indices.
        input_dict: Mapping from feature name to that feature's raw input.
        sparse_fg_list: Iterable of feature-group objects. Each must expose
            ``name``, ``hash_flag`` and ``dimension`` attributes.
        return_feat_list: Optional collection of feature names to keep.
            When empty or ``None``, every feature in ``sparse_fg_list`` is
            processed.
        mask_feat_list: Optional collection of feature names whose ``Hash``
            lookup is built with ``mask_zero=True``. ``None`` is treated as
            empty.

    Returns:
        A list of ``embedding_dict[name](lookup_idx)`` results, in the
        order of ``sparse_fg_list``. Features filtered out by
        ``return_feat_list`` are skipped.
    """
    # Accept None as an explicit "no filter" / "no masking" value. The
    # original code raised TypeError on len(None).
    if return_feat_list is None:
        return_feat_list = ()
    if mask_feat_list is None:
        mask_feat_list = ()

    embedding_vec_list = []
    for fg in sparse_fg_list:
        feat_name = fg.name
        # An empty filter means "keep every feature"; otherwise skip
        # features that are not explicitly requested.
        if return_feat_list and feat_name not in return_feat_list:
            continue
        if fg.hash_flag:
            # Raw input is converted to indices by the project's Hash layer.
            # fg.dimension is passed as its first argument (presumably the
            # bucket count; confirm against Hash's definition).
            lookup_idx = Hash(fg.dimension,
                              mask_zero=(feat_name in mask_feat_list))(input_dict[feat_name])
        else:
            # Input is already usable as lookup indices.
            lookup_idx = input_dict[feat_name]
        embedding_vec_list.append(embedding_dict[feat_name](lookup_idx))
    return embedding_vec_list
def get_varlen_embedding_vec_dict(embedding_dict, sequence_input_dict, sequence_fg_list):
    """Map each sequence feature's name to its embedded lookup result.

    Args:
        embedding_dict: Mapping from feature name to a callable (presumably
            an embedding layer) applied to the feature's lookup indices.
        sequence_input_dict: Mapping from feature name to that feature's
            raw sequence input.
        sequence_fg_list: Iterable of feature-group objects. Each must
            expose ``name``, ``hash_flag`` and ``dimension`` attributes.

    Returns:
        A dict ``{name: embedding_dict[name](lookup_idx)}``. If a name
        appears more than once, the last occurrence wins.
    """
    def _embed_one(group):
        # Compute the lookup indices first, then fetch the embedding
        # callable (same evaluation order as before).
        key = group.name
        if group.hash_flag:
            # Sequence features always hash with mask_zero=True.
            indices = Hash(group.dimension, mask_zero=True)(sequence_input_dict[key])
        else:
            indices = sequence_input_dict[key]
        return embedding_dict[key](indices)

    return {group.name: _embed_one(group) for group in sequence_fg_list}