def main(config_name='config_infer.json'):
    # K.clear_session()
    with open(config_name) as f:
        config = json.load(f)

    # Reading datasets from files
    reader_config = config['dataset_reader']
    reader = REGISTRY[reader_config['name']]
    data = reader.read(reader_config['data_path'])

    # Building dict of datasets
    dataset_config = config['dataset']
    dataset = from_params(REGISTRY[dataset_config['name']],
                          dataset_config, data=data)

    # Merging train and valid dataset for further split on train/valid
    # dataset.merge_data(fields_to_merge=['train', 'valid'], new_field='train')
    # dataset.split_data(field_to_split='train', new_fields=['train', 'valid'], proportions=[0.9, 0.1])

    preproc_config = config['preprocessing']
    preproc = from_params(REGISTRY[preproc_config['name']],
                          preproc_config)
    # dataset = preproc.preprocess(dataset=dataset, data_type='train')
    # dataset = preproc.preprocess(dataset=dataset, data_type='valid')
    dataset = preproc.preprocess(dataset=dataset, data_type='test')

    # Extracting unique classes
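
# A purely illustrative config for the inference main() above; component names
# and the data path are placeholders. Only the top-level keys ('dataset_reader',
# 'dataset', 'preprocessing') and their 'name'/'data_path' fields come from the code.
example_infer_config = {
    "dataset_reader": {
        "name": "some_registered_reader",   # hypothetical registry key
        "data_path": "data/"                # hypothetical path
    },
    "dataset": {"name": "some_registered_dataset"},             # hypothetical registry key
    "preprocessing": {"name": "some_registered_preprocessor"}   # hypothetical registry key
}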

def fit_chainer(config: dict, iterator: BasicDatasetIterator) -> Chainer:
    chainer_config: dict = config['chainer']
    chainer = Chainer(chainer_config['in'], chainer_config['out'], chainer_config.get('in_y'))
    for component_config in chainer_config['pipe']:
        component = from_params(component_config, vocabs=[], mode='train')
        if 'fit_on' in component_config:
            component: Estimator
            preprocessed = chainer(*iterator.iter_all('train'), to_return=component_config['fit_on'])
            if len(component_config['fit_on']) == 1:
                preprocessed = [preprocessed]
            else:
                preprocessed = zip(*preprocessed)
            component.fit(*preprocessed)
            component.save()
        if 'in' in component_config:
            c_in = component_config['in']
            c_out = component_config['out']
            in_y = component_config.get('in_y', None)
            main = component_config.get('main', False)
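
# Illustrative sketch of the 'chainer' config section that fit_chainer() reads.
# Only the key structure ('in', 'in_y', 'out', 'pipe', 'fit_on', 'main') is taken
# from the code above; the component name and field names are assumptions.
example_chainer_config = {
    "in": ["x"],
    "in_y": ["y"],
    "out": ["y_predicted"],
    "pipe": [
        {
            "name": "some_estimator",   # hypothetical registered component
            "fit_on": ["x", "y"],       # fields the estimator is fitted on
            "in": ["x"],
            "out": ["y_predicted"],
            "main": True
        }
    ]
}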

    kwargs = {k: v for k, v in reader_config.items() if k not in ['name', 'data_path']}
    data = reader.read(data_path, **kwargs)

    iterator_config = config['dataset_iterator']
    iterator: BasicDatasetIterator = from_params(iterator_config, data=data)

    if 'chainer' in config:
        model = fit_chainer(config, iterator)
    else:
        vocabs = config.get('vocabs', {})
        for vocab_param_name, vocab_config in vocabs.items():
            v: Estimator = from_params(vocab_config, mode='train')
            vocabs[vocab_param_name] = _fit(v, iterator)
        model_config = config['model']
        model = from_params(model_config, vocabs=vocabs, mode='train')

    train_config = {
        'metrics': ['accuracy'],
        'validate_best': True,
        'test_best': True
    }
    try:
        train_config.update(config['train'])
    except KeyError:
        log.warning('Train config is missing. Populating with default values')

    metrics_functions = list(zip(train_config['metrics'],
                                 get_metrics_by_names(train_config['metrics'])))

            try:
                module_name, cls_name = c.split(':')
                reader = getattr(importlib.import_module(module_name), cls_name)()
            except ValueError:
                e = ConfigError('Expected class description in a `module.submodules:ClassName` form, but got `{}`'
                                .format(c))
                log.exception(e)
                raise e
        else:
            reader = get_model(reader_config.pop('name'))()
        data_path = expand_path(reader_config.pop('data_path', ''))
        data = reader.read(data_path, **reader_config)
    else:
        log.warning("No dataset reader is provided in the JSON config.")

    iterator_config = config['dataset_iterator']
    iterator: Union[DataLearningIterator, DataFittingIterator] = from_params(iterator_config,
                                                                             data=data)

    train_config = {
        'metrics': ['accuracy'],
        'validate_best': to_validate,
        'test_best': True
    }
    try:
        train_config.update(config['train'])
    except KeyError:
        log.warning('Train config is missing. Populating with default values')

    metrics_functions = list(zip(train_config['metrics'], get_metrics_by_names(train_config['metrics'])))

    if to_train:
        reader_config = config['dataset_reader']
        reader = get_model(reader_config['name'])()
        data_path = expand_path(reader_config.get('data_path', ''))
        kwargs = {k: v for k, v in reader_config.items() if k not in ['name', 'data_path']}
        data = reader.read(data_path, **kwargs)

        iterator_config = config['dataset_iterator']
        iterator: BasicDatasetIterator = from_params(iterator_config, data=data)

        if 'chainer' in config:
            model = fit_chainer(config, iterator)
        else:
            vocabs = config.get('vocabs', {})
            for vocab_param_name, vocab_config in vocabs.items():
                v: Estimator = from_params(vocab_config, mode='train')
                vocabs[vocab_param_name] = _fit(v, iterator)
            model_config = config['model']
            model = from_params(model_config, vocabs=vocabs, mode='train')

        train_config = {
            'metrics': ['accuracy'],
            'validate_best': True,
            'test_best': True
        }
        try:
            train_config.update(config['train'])
        except KeyError:
            log.warning('Train config is missing. Populating with default values')

def main(config_name='config.json'):
    with open(config_name) as f:
        config = json.load(f)

    # Reading datasets from files
    reader_config = config['dataset_reader']
    reader = _REGISTRY[reader_config['name']]
    data = reader.read(train_data_path=reader_config.get('train_data_path'),
                       valid_data_path=reader_config.get('valid_data_path'),
                       test_data_path=reader_config.get('test_data_path'))

    # Building dict of datasets
    dataset_config = config['dataset']
    dataset = from_params(_REGISTRY[dataset_config['name']],
                          dataset_config, data=data)

    # Merging train and valid dataset for further split on train/valid
    dataset.merge_data(fields_to_merge=['train', 'valid'], new_field='train')
    dataset.split_data(field_to_split='train', new_fields=['train', 'valid'], proportions=[0.9, 0.1])

    # Extracting unique classes
    intents = dataset.extract_classes()
    print("Considered intents:", intents)

    # Initializing model
    model_config = config['model']
    model = from_params(_REGISTRY[model_config['name']],
                        model_config, opt=model_config, classes=intents)
    print("Network parameters: ", model.network_params)

def train(config_path, usr_dir):
    config = read_json(config_path)
    model_config = config['model']
    model_name = model_config['name']

    # Path for models should be specified here:
    model = from_params(_REGISTRY[model_name], model_config, models_path=usr_dir)

    reader_config = config['dataset_reader']
    reader = _REGISTRY[reader_config['name']]
    data = reader.read(reader_config.get('data_path', usr_dir))

    dataset_config = config['dataset']
    dataset_name = dataset_config['name']
    dataset = from_params(_REGISTRY[dataset_name], dataset_config, data=data)

    model.train(dataset.iter_all())
    model.save()
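
# Hypothetical invocation of the train() helper above; both the config path
# and the model directory are placeholders, not paths from the original project.
if __name__ == '__main__':
    train('config.json', './user_models')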

    # Building dict of datasets
    dataset_config = config['dataset']
    dataset = from_params(_REGISTRY[dataset_config['name']],
                          dataset_config, data=data)

    # Merging train and valid dataset for further split on train/valid
    dataset.merge_data(fields_to_merge=['train', 'valid'], new_field='train')
    dataset.split_data(field_to_split='train', new_fields=['train', 'valid'], proportions=[0.9, 0.1])

    # Extracting unique classes
    intents = dataset.extract_classes()
    print("Considered intents:", intents)

    # Initializing model
    model_config = config['model']
    model = from_params(_REGISTRY[model_config['name']],
                        model_config, opt=model_config, classes=intents)
    print("Network parameters: ", model.network_params)
    print("Learning parameters:", model.learning_params)

    test_batch_gen = dataset.batch_generator(batch_size=model.learning_params['batch_size'],
                                             data_type='test')
    test_preds = []
    test_true = []
    for test_id, test_batch in enumerate(test_batch_gen):
        test_preds.extend(model.infer(test_batch[0]))
        test_true.extend(model.labels2onehot(test_batch[1]))
        if model_config['show_examples'] and test_id == 0:
            for j in range(model.learning_params['batch_size']):
                print(test_batch[0][j],

    # Building dict of datasets
    dataset_config = config['dataset']
    dataset = from_params(_REGISTRY[dataset_config['name']],
                          dataset_config, data=data)

    # Merging train and valid dataset for further split on train/valid
    dataset.merge_data(fields_to_merge=['train', 'valid'], new_field='train')
    dataset.split_data(field_to_split='train', new_fields=['train', 'valid'], proportions=[0.9, 0.1])

    # Extracting unique classes
    intents = dataset.extract_classes()
    print("Considered intents:", intents)

    # Initializing model
    model_config = config['model']
    model = from_params(_REGISTRY[model_config['name']],
                        model_config, opt=model_config, classes=intents)
    print("Network parameters: ", model.network_params)
    print("Learning parameters:", model.learning_params)
    print("Considered:", model.metrics_names)

    if 'valid' in data.keys():
        print('___Validation set is given___')
    elif 'val_split' in model.learning_params.keys():
        print('___Validation split is given___')
    else:
        print('___Validation set and validation split are not given.____\n____Validation split = 0.1____')
        model.learning_params['val_split'] = 0.1

    updates = 0
    val_loss = 1e100
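
# The 'valid' check above implies that `data` is a plain dict keyed by split
# name; the sketch below is an assumed minimal shape with placeholder samples.
example_data = {
    'train': [('example utterance', 'some_intent')],   # hypothetical (text, label) pairs
    'valid': [],
    'test': []
}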