Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_estimator(self):
# Prepare the inputs for a model function built by `self.model_fn_builder`.
#
# NOTE(review): this excerpt has lost its indentation, and the method is cut
# off inside the `self.model_fn_builder(` call below. The remaining arguments,
# the estimator construction and the return value are not visible here.
#
# Function-scope imports; neither name is used in the visible lines.
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.run_config import RunConfig
# Load the BERT configuration from the JSON file named by `self.config_name`.
bert_config = modeling.BertConfig.from_json_file(self.config_name)
label_list = self.processor.get_labels()
# The training examples are loaded regardless of `self.mode`; in the visible
# lines they are only used to derive the step counts below.
train_examples = self.processor.get_train_examples(self.data_dir)
# Total steps = examples / batch size * epochs, truncated by int().
num_train_steps = int(
len(train_examples) / self.batch_size * self.num_train_epochs)
# Warm-up covers the first 10% of the training steps (truncated by int()).
num_warmup_steps = int(num_train_steps * 0.1)
# Choose what to initialize from: `self.ckpt_name` when training, otherwise
# `self.output_dir`.
# NOTE(review): presumably the pretrained checkpoint vs. the directory holding
# the fine-tuned model -- confirm against the caller.
if self.mode == tf.estimator.ModeKeys.TRAIN:
init_checkpoint = self.ckpt_name
else:
init_checkpoint = self.output_dir
# Build the model function. Only the first four keyword arguments are visible;
# the call is not closed in this excerpt.
model_fn = self.model_fn_builder(
bert_config=bert_config,
num_labels=len(label_list),
init_checkpoint=init_checkpoint,
learning_rate=self.learning_rate,
def train(self):
if self.mode is None:
raise ValueError("Please set the 'mode' parameter")
bert_config = modeling.BertConfig.from_json_file(self.config_name)
if self.max_seq_len > bert_config.max_position_embeddings:
raise ValueError(
"Cannot use sequence length %d because the BERT model "
"was only trained up to sequence length %d" %
(self.max_seq_len, bert_config.max_position_embeddings))
tf.gfile.MakeDirs(self.output_dir)
label_list = self.processor.get_labels()
train_examples = self.processor.get_train_examples(self.data_dir)
num_train_steps = int(len(train_examples) / self.batch_size * self.num_train_epochs)
estimator = self.get_estimator()