# [scrape residue] Snyk advertisement banner captured with the snippet; not part of the program:
# "Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately."
# NOTE(review): this is a whitespace-mangled and TRUNCATED copy of
# __create_stats -- every line has lost its indentation (the body sits at the
# same column as the `def`, which is not valid Python) and the fragment stops
# abruptly after constructing sidekit.Mixture(), before the model is ever
# loaded or used. A second, longer copy of the same method follows below.
# Restore indentation against the original repository and deduplicate before
# attempting to run this.
def __create_stats(self):
"""
Create the Statistic Servers used for i-vector / PLDA training.

Visible behavior in this fragment:
- reads the total-variability IdMap from <BASE_DIR>/task/tv_idmap.h5;
- if self.ENABLE_PLDA is set, reads the PLDA IdMap and merges it with
  the TV IdMap into a joint background IdMap, raising RuntimeError if
  the merged map fails validation;
- checks for a trained UBM file ubm_<NUM_GAUSSIANS>.h5 under
  <BASE_DIR>/ubm and trains one via UBM(self.conf_path) if missing;
- constructs an empty sidekit.Mixture() -- presumably the next
  (missing) line would read the trained UBM from disk. TODO confirm
  against the full source.
"""
# Read tv_idmap
tv_idmap = sidekit.IdMap.read(os.path.join(self.BASE_DIR, "task", "tv_idmap.h5"))
# Default background map is the TV map alone (overridden below when PLDA is on).
back_idmap = tv_idmap
# If PLDA is enabled
if self.ENABLE_PLDA:
# Read plda_idmap
plda_idmap = sidekit.IdMap.read(os.path.join(self.BASE_DIR, "task", "plda_idmap.h5"))
# Create a joint StatServer for TV and PLDA training data
back_idmap = plda_idmap.merge(tv_idmap)
if not back_idmap.validate():
raise RuntimeError("Error merging tv_idmap & plda_idmap")
# Check UBM model
ubm_name = "ubm_{}.h5".format(self.NUM_GAUSSIANS)
ubm_path = os.path.join(self.BASE_DIR, "ubm", ubm_name)
if not os.path.exists(ubm_path):
# If the UBM model does not exist, train one.
logging.info("Training UBM-{} model".format(self.NUM_GAUSSIANS))
# NOTE(review): UBM is presumably a project class wrapping sidekit's UBM
# training -- it is not defined in this file; verify the import.
ubm = UBM(self.conf_path)
ubm.train()
# Load the trained UBM model.
logging.info("Loading trained UBM-{} model".format(self.NUM_GAUSSIANS))
# NOTE(review): fragment ends here -- ubm is created empty but never
# populated (no ubm.read(ubm_path) visible). Truncation suspected.
ubm = sidekit.Mixture()
# NOTE(review): second, longer copy of __create_stats in the same file --
# also whitespace-mangled (no indentation) and with lines evidently MISSING:
# `ubm` is used below without being constructed in this fragment, and the
# names `server`, `SAVE`, `tv_stat`, `tv_filename`, and `back_filename` are
# referenced but never defined in the visible text. Treat this as a corrupted
# extraction; recover the original from the upstream repository.
def __create_stats(self):
"""
Create the Statistic Servers used for i-vector / PLDA training.

Visible behavior in this fragment:
- reads the TV IdMap and (when self.ENABLE_PLDA) merges it with the
  PLDA IdMap into a background IdMap, raising RuntimeError on a
  failed validation;
- trains a UBM with NUM_GAUSSIANS components if the saved model file
  is missing, then writes it under <BASE_DIR>/ubm;
- accumulates zero-/first-order sufficient statistics for the
  enrollment and test data via StatServer.accumulate_stat and writes
  them under <BASE_DIR>/stat;
- when PLDA is enabled, extracts the PLDA subset of the background
  statistics via StatServer.read_subset and saves it.
"""
# Read tv_idmap
tv_idmap = sidekit.IdMap.read(os.path.join(self.BASE_DIR, "task", "tv_idmap.h5"))
# Default background map is the TV map alone (overridden below when PLDA is on).
back_idmap = tv_idmap
# If PLDA is enabled
if self.ENABLE_PLDA:
# Read plda_idmap
plda_idmap = sidekit.IdMap.read(os.path.join(self.BASE_DIR, "task", "plda_idmap.h5"))
# Create a joint StatServer for TV and PLDA training data
back_idmap = plda_idmap.merge(tv_idmap)
if not back_idmap.validate():
raise RuntimeError("Error merging tv_idmap & plda_idmap")
# Check UBM model
ubm_name = "ubm_{}.h5".format(self.NUM_GAUSSIANS)
ubm_path = os.path.join(self.BASE_DIR, "ubm", ubm_name)
if not os.path.exists(ubm_path):
# If the UBM model does not exist, train one.
logging.info("Training UBM-{} model".format(self.NUM_GAUSSIANS))
# EM training schedule (presumably sidekit's split-and-retrain scheme;
# the actual training call is missing from this fragment -- TODO confirm):
# -> 2 iterations of EM with 2 distributions
# -> 2 iterations of EM with 4 distributions
# -> 4 iterations of EM with 8 distributions
# -> 4 iterations of EM with 16 distributions
# -> 4 iterations of EM with 32 distributions
# -> 4 iterations of EM with 64 distributions
# -> 8 iterations of EM with 128 distributions
# -> 8 iterations of EM with 256 distributions
# -> 8 iterations of EM with 512 distributions
# -> 8 iterations of EM with 1024 distributions
model_dir = os.path.join(self.BASE_DIR, "ubm")
# NOTE(review): `ubm` is not constructed anywhere in this fragment --
# the lines creating/training it were evidently lost in extraction.
logging.info("Saving the model {} at {}".format(ubm.name, model_dir))
ubm.write(os.path.join(model_dir, ubm.name))
# Read idmap for the enrolling data
enroll_idmap = sidekit.IdMap.read(os.path.join(self.BASE_DIR, "task", "enroll_idmap.h5"))
# Create Statistic Server to store/process the enrollment data
enroll_stat = sidekit.StatServer(statserver_file_name=enroll_idmap,
ubm=ubm)
logging.debug(enroll_stat)
# NOTE(review): `server` (a sidekit.FeaturesServer, presumably created by
# self.createFeatureServer() in a missing line) is undefined here -- confirm.
server.feature_filename_structure = os.path.join(self.BASE_DIR, "feat", "{}.h5")
# Compute the sufficient statistics for a list of sessions whose indices are segIndices.
# BUG (original author's note): don't use self.NUM_THREADS when assigning
# num_thread, as it is prone to race conditions.
enroll_stat.accumulate_stat(ubm=ubm,
feature_server=server,
seg_indices=range(enroll_stat.segset.shape[0])
)
# NOTE(review): `SAVE` is undefined in this fragment -- presumably a
# module-level or config flag; verify against the full source.
if SAVE:
# Save the status of the enroll data
filename = "enroll_stat_{}.h5".format(self.NUM_GAUSSIANS)
enroll_stat.write(os.path.join(self.BASE_DIR, "stat", filename))
# NOTE(review): `tv_stat` and `tv_filename` are undefined here -- the
# TV-statistics accumulation block appears to be missing.
tv_stat.write(os.path.join(self.BASE_DIR, "stat", tv_filename))
# Load sufficient statistics and extract i-vectors from PLDA training data
if self.ENABLE_PLDA:
plda_filename = 'plda_stat_{}.h5'.format(self.NUM_GAUSSIANS)
if not os.path.isfile(os.path.join(self.BASE_DIR, "stat", plda_filename)):
# NOTE(review): `back_filename` is undefined in this fragment -- it
# presumably names the background-statistics file written earlier.
plda_stat = sidekit.StatServer.read_subset(
os.path.join(self.BASE_DIR, "stat", back_filename),
plda_idmap
)
plda_stat.write(os.path.join(self.BASE_DIR, "stat", plda_filename))
# Load sufficient statistics from test data
filename = 'test_stat_{}.h5'.format(self.NUM_GAUSSIANS)
if not os.path.isfile(os.path.join(self.BASE_DIR, "stat", filename)):
test_idmap = sidekit.IdMap.read(os.path.join(self.BASE_DIR, "task", "test_idmap.h5"))
test_stat = sidekit.StatServer( statserver_file_name=test_idmap,
ubm=ubm
)
# Create Feature Server
fs = self.createFeatureServer()
# Jointly compute the sufficient statistics of TV and PLDA data
# BUG (original author's note): don't use self.NUM_THREADS when assigning
# num_thread, as it is prone to race conditions.
test_stat.accumulate_stat(ubm=ubm,
feature_server=fs,
seg_indices=range(test_stat.segset.shape[0])
)
test_stat.write(os.path.join(self.BASE_DIR, "stat", filename))