pretrained = not model_parameters and not pretrained_bert_parameters and not args.sentencepiece
bert, vocab = nlp.model.get_model(
name=model_name,
dataset_name=dataset_name,
vocab=vocab,
pretrained=pretrained,
ctx=ctx,
use_pooler=False,
use_decoder=False,
use_classifier=False)
if args.sentencepiece:
tokenizer = nlp.data.BERTSPTokenizer(args.sentencepiece, vocab, lower=lower)
else:
tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=lower)
batchify_fn = nlp.data.batchify.Tuple(
nlp.data.batchify.Stack(),
nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),
nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),
nlp.data.batchify.Stack('float32'),
nlp.data.batchify.Stack('float32'),
nlp.data.batchify.Stack('float32'))
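# Hedged usage sketch (not part of the original snippet): this Tuple/Pad/Stack
# batchify_fn is meant to be handed to a Gluon DataLoader, so the two padded
# fields are padded per batch while the scalar fields are simply stacked.
# 'train_data_transformed' and 'batch_size' are hypothetical placeholders, and
# 'import mxnet as mx' is assumed.
train_dataloader = mx.gluon.data.DataLoader(
    train_data_transformed, batch_size=batch_size, shuffle=True,
    batchify_fn=batchify_fn, last_batch='rollover')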
net = BertForQA(bert=bert)
if model_parameters:
# load complete BertForQA parameters
net.load_parameters(model_parameters, ctx=ctx, cast_dtype=True)
elif pretrained_bert_parameters:
# only load BertModel parameters
bert.load_parameters(pretrained_bert_parameters, ctx=ctx,
                     ignore_extra=True, cast_dtype=True)
if pretrained_bert_parameters:
logging.info('loading bert params from %s', pretrained_bert_parameters)
model.bert.load_parameters(pretrained_bert_parameters, ctx=ctx,
ignore_extra=True)
if model_parameters:
logging.info('loading model params from %s', model_parameters)
model.load_parameters(model_parameters, ctx=ctx)
nlp.utils.mkdir(output_dir)
logging.debug(model)
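# hybridizing with static_alloc=True compiles the computation graph once and
# reuses intermediate memory buffers across batches during fine-tuning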
model.hybridize(static_alloc=True)
loss_function.hybridize(static_alloc=True)
# data processing
do_lower_case = 'uncased' in dataset
bert_tokenizer = BERTTokenizer(vocabulary, lower=do_lower_case)
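# Illustrative check (not from the original script): the wordpiece tokenizer
# returns subword tokens, and indexing the vocabulary with that token list maps
# them to ids. The sample sentence is arbitrary.
sample_tokens = bert_tokenizer('gluonnlp makes bert easy to use')
sample_ids = vocabulary[sample_tokens]
logging.debug('tokens: %s ids: %s', sample_tokens, sample_ids)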
def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, pad=False):
"""Train/eval Data preparation function."""
pool = multiprocessing.Pool()
# transformation applied to the train and dev data
label_dtype = 'float32' if not task.class_labels else 'int32'
trans = BERTDatasetTransform(tokenizer, max_len,
class_labels=task.class_labels,
label_alias=task.label_alias,
pad=pad, pair=task.is_pair,
has_label=True)
# training data
# task.dataset_train returns (segment_name, dataset)
train_tsv = task.dataset_train()[1]
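# Hedged sketch of the usual continuation (the snippet is truncated here):
# apply the BERTDatasetTransform to every training record through the worker
# pool created above and wrap the result so it can be batched.
# Assumes 'import mxnet as mx'.
data_train = mx.gluon.data.SimpleDataset(pool.map(trans, train_tsv))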
def train(args):
ctx = mx.cpu() if args.gpu is None else mx.gpu(args.gpu)
dataset_name = 'book_corpus_wiki_en_cased' if args.cased else 'book_corpus_wiki_en_uncased'
bert_model, bert_vocab = nlp.model.get_model(name=args.bert_model,
dataset_name=dataset_name,
pretrained=True,
ctx=ctx,
use_pooler=True,
use_decoder=False,
use_classifier=False,
dropout=args.dropout_prob,
embed_dropout=args.dropout_prob)
tokenizer = BERTTokenizer(bert_vocab, lower=not args.cased)
if args.dataset == 'atis':
train_data = ATISDataset('train')
dev_data = ATISDataset('dev')
test_data = ATISDataset('test')
intent_vocab = train_data.intent_vocab
slot_vocab = train_data.slot_vocab
elif args.dataset == 'snips':
train_data = SNIPSDataset('train')
dev_data = SNIPSDataset('dev')
test_data = SNIPSDataset('test')
intent_vocab = train_data.intent_vocab
slot_vocab = train_data.slot_vocab
else:
raise NotImplementedError
print('Dataset {}'.format(args.dataset))
print(' #Train/Dev/Test = {}/{}/{}'.format(len(train_data), len(dev_data), len(test_data)))
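# (illustrative addition, not in the original script) the label vocabularies
# determine the classifier output sizes, so printing them is a cheap sanity check
print(' #Intent/Slot labels = {}/{}'.format(len(intent_vocab), len(slot_vocab)))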
bert, vocabulary = nlp.model.get_model(args.gluon_model,
                                       vocab=vocab,
pretrained=not args.gluon_parameter_file,
use_pooler=False,
use_decoder=False,
use_classifier=False)
if args.gluon_parameter_file:
try:
bert.cast('float16')
bert.load_parameters(args.gluon_parameter_file, ignore_extra=True)
bert.cast('float32')
except AssertionError:
bert.cast('float32')
bert.load_parameters(args.gluon_parameter_file, ignore_extra=True)
else:
assert not args.gluon_vocab_file, \
'Cannot specify --gluon_vocab_file without specifying --gluon_parameter_file'
bert, vocabulary = nlp.model.get_model(args.gluon_model,
dataset_name=args.gluon_dataset,
pretrained=not args.gluon_parameter_file,
use_pooler=False,
use_decoder=False,
use_classifier=False)
print(bert)
tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=do_lower_case)
dataset = nlp.data.TSVDataset(input_file, field_separator=nlp.data.Splitter(' ||| '))
trans = nlp.data.BERTSentenceTransform(tokenizer, max_length)
dataset = dataset.transform(trans)
bert_dataloader = mx.gluon.data.DataLoader(dataset, batch_size=1,
shuffle=True, last_batch='rollover')
# verify the output of the first sample
for i, seq in enumerate(bert_dataloader):
input_ids, valid_length, type_ids = seq
out = bert(input_ids, type_ids,
valid_length.astype('float32'))
length = valid_length.asscalar()
a = tf_outputs[-1][:length]
b = out[0][:length].asnumpy()
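# Hedged sketch of the comparison step (the tolerance is illustrative): the two
# [valid_length, hidden_size] slices should match closely if the converted
# parameters are equivalent. Assumes 'import numpy as np'.
print('max abs difference:', np.max(np.abs(a - b)))
assert np.allclose(a, b, atol=1e-3), 'TF and Gluon outputs diverge'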
def predict(dataset, all_results, vocab):
tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=True)
transform = bert.data.qa.SQuADTransform(tokenizer, is_pad=False, is_training=False, do_lookup=False)
dev_dataset = dataset.transform(transform._transform)
from bert.bert_qa_evaluate import PredResult, predict
all_results_np = collections.defaultdict(list)
for example_ids, pred_start, pred_end in all_results:
batch_size = example_ids.shape[0]
example_ids = example_ids.asnumpy().tolist()
pred_start = pred_start.reshape(batch_size, -1).asnumpy()
pred_end = pred_end.reshape(batch_size, -1).asnumpy()
for example_id, start, end in zip(example_ids, pred_start, pred_end):
all_results_np[example_id].append(PredResult(start=start, end=end))
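# at this point all_results_np maps every example id to the list of
# PredResult(start=..., end=...) score chunks gathered across batches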
all_predictions = collections.OrderedDict()
top_results = []
pretrained = not model_parameters and not pretrained_bert_parameters and not args.sentencepiece
bert, vocab = nlp.model.get_model(
name=model_name,
dataset_name=dataset_name,
vocab=vocab,
pretrained=pretrained,
ctx=ctx,
use_pooler=False,
use_decoder=False,
use_classifier=False)
if args.sentencepiece:
tokenizer = nlp.data.BERTSPTokenizer(args.sentencepiece, vocab, lower=lower)
else:
tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=lower)
batchify_fn = nlp.data.batchify.Tuple(
nlp.data.batchify.Stack(),
nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token], round_to=args.round_to),
nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token], round_to=args.round_to),
nlp.data.batchify.Stack('float32'),
nlp.data.batchify.Stack('float32'),
nlp.data.batchify.Stack('float32'))
# load symbolic model
deploy = args.deploy
model_prefix = args.model_prefix
net = BertForQA(bert=bert)
if model_parameters:
# load complete BertForQA parameters
net.load_parameters(model_parameters, ctx=ctx, cast_dtype=True)
dataset_name = None
vocab = nlp.vocab.BERTVocab.from_sentencepiece(args.sentencepiece)
model, vocab = get_model_loss(ctxs, args.model, args.pretrained,
dataset_name, vocab, args.dtype,
ckpt_dir=args.ckpt_dir,
start_step=args.start_step)
logging.info('Model created')
data_eval = args.data_eval
if args.raw:
if args.sentencepiece:
tokenizer = nlp.data.BERTSPTokenizer(args.sentencepiece, vocab,
lower=not args.cased)
else:
tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=not args.cased)
cache_dir = os.path.join(args.ckpt_dir, 'data_eval_cache')
cache_file = os.path.join(cache_dir, 'part-000.npz')
nlp.utils.mkdir(cache_dir)
# generate dev dataset from the raw text if needed
if not args.eval_use_npz:
data_eval = cache_file
if not os.path.isfile(cache_file) and rank == 0:
generate_dev_set(tokenizer, vocab, cache_file, args)
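# Hedged sketch (not in the original script): once the cache file exists it can
# be inspected as a regular NumPy archive; the key names depend on how
# generate_dev_set serialized the features. Assumes 'import numpy as np'.
eval_arrays = np.load(data_eval, allow_pickle=True)
logging.debug('cached eval arrays: %s', list(eval_arrays.keys()))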
logging.debug('Random seed set to %d', random_seed)
mx.random.seed(random_seed)
if args.data:
if args.raw:
vocab = nlp.vocab.BERTVocab.from_sentencepiece(args.sentencepiece)
model, nsp_loss, mlm_loss, vocab = get_model_loss([ctx], args.model, args.pretrained,
dataset_name, vocab, args.dtype,
ckpt_dir=args.ckpt_dir,
start_step=args.start_step)
logging.debug('Model created')
data_eval = args.data_eval
if args.raw:
if args.sentencepiece:
tokenizer = nlp.data.BERTSPTokenizer(args.sentencepiece, vocab,
num_best=args.sp_nbest,
alpha=args.sp_alpha, lower=not args.cased)
else:
tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=not args.cased)
cache_dir = os.path.join(args.ckpt_dir, 'data_eval_cache')
cache_file = os.path.join(cache_dir, 'part-000.npz')
nlp.utils.mkdir(cache_dir)
# generate dev dataset from the raw text if needed
if not args.eval_use_npz:
data_eval = cache_file
if not os.path.isfile(cache_file) and rank == 0:
generate_dev_set(tokenizer, vocab, cache_file, args)