parser.add_argument("--export_dir", help="directory to export saved_model")
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized CSV format")
parser.add_argument("--input_mode", help="input mode (tf|spark)", default="tf")
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized CSV format")
parser.add_argument("--model_dir", help="directory to write model checkpoints")
parser.add_argument("--num_ps", help="number of ps nodes", type=int, default=1)
parser.add_argument("--steps_per_epoch", help="number of steps per epoch", type=int, default=300)
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
args = parser.parse_args()
print("args:", args)
if args.input_mode == 'tf':
    # In TENSORFLOW input mode, each TF node reads its data directly (e.g. from HDFS).
    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard,
                            TFCluster.InputMode.TENSORFLOW, log_dir=args.model_dir)
else:  # args.input_mode == 'spark'
    # In SPARK input mode, the driver feeds the training data to the cluster as an RDD.
    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard,
                            TFCluster.InputMode.SPARK, log_dir=args.model_dir)
    images = sc.textFile(args.images).map(lambda ln: [float(x) for x in ln.split(',')])
    labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
    dataRDD = images.zip(labels)
    cluster.train(dataRDD, args.epochs)
cluster.shutdown()
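# A minimal sketch (assumed, not from the original snippet) of the main_fun(args, ctx)
# that TFCluster.run invokes on each executor. The (args, ctx) signature,
# ctx.job_name/ctx.task_index, and the TFNode.DataFeed API for InputMode.SPARK come
# from TensorFlowOnSpark; the training step itself is left as a placeholder.
def main_fun(args, ctx):
    from tensorflowonspark import TFNode

    print("job_name: {0}, task_index: {1}".format(ctx.job_name, ctx.task_index))

    # In SPARK input mode, batches arrive from the RDD fed via cluster.train().
    tf_feed = TFNode.DataFeed(ctx.mgr)
    while not tf_feed.should_stop():
        batch = tf_feed.next_batch(100)  # batch size chosen arbitrarily for this sketch
        if len(batch) > 0:
            images, labels = zip(*batch)
            # ... run one training step on (images, labels) with your model ...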
# Example 2: TF2/Keras MNIST driver with a 'chief' master node and a dedicated evaluator.
# Assumes main_fun(args, ctx) is defined above and sc is an active SparkContext.
executors = sc._conf.get("spark.executor.instances")
num_executors = int(executors) if executors is not None else 1
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64)
parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=10000)
parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
parser.add_argument("--epochs", help="number of epochs", type=int, default=3)
parser.add_argument("--learning_rate", help="learning rate", type=float, default=1e-4)
parser.add_argument("--model_dir", help="path to save checkpoint", default="mnist_model")
parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export")
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
args = parser.parse_args()
print("args:", args)
cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard,
                        input_mode=TFCluster.InputMode.TENSORFLOW, log_dir=args.model_dir,
                        master_node='chief', eval_node=True)
cluster.shutdown(grace_secs=120)
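# A minimal sketch (not the original example's code) of a main_fun for the
# chief/evaluator layout above, assuming TF 2.x. TensorFlowOnSpark sets the
# TF_CONFIG environment variable on each node before invoking this function,
# which lets MultiWorkerMirroredStrategy discover the cluster. The model and
# dataset below are placeholders.
def main_fun(args, ctx):
    import tensorflow as tf

    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(optimizer=tf.keras.optimizers.Adam(args.learning_rate),
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

    # In InputMode.TENSORFLOW, each worker loads its data itself.
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    model.fit(x_train / 255.0, y_train, batch_size=args.batch_size, epochs=args.epochs)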
# Example 3: TF1-style CIFAR-10 training loop excerpt (from inside main_fun);
# the model, loss, FLAGS, and _LoggerHook definitions are elided from this snippet.
with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)
if __name__ == '__main__':
    sc = SparkContext(conf=SparkConf().setAppName("cifar10_train"))
    num_executors = int(sc._conf.get("spark.executor.instances"))
    num_ps = 0
    # Positional args: sc, map_fun, tf_args, num_executors, num_ps, tensorboard, input_mode.
    cluster = TFCluster.run(sc, main_fun, sys.argv, num_executors, num_ps, False, TFCluster.InputMode.TENSORFLOW)
    cluster.shutdown()
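# A typical launch command for the driver above; the master URL, resource sizes,
# and paths are assumptions, not part of the original snippet:
#
#   ${SPARK_HOME}/bin/spark-submit \
#       --master yarn --deploy-mode client \
#       --num-executors 4 --executor-memory 4G \
#       cifar10_train.py --max_steps 1000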
parser.add_argument("--images_labels", help="HDFS path to MNIST image_label files in parallelized format")
parser.add_argument("--mode", help="train|inference", default="train")
parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("--num_ps", help="number of ps nodes", default=1)
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("--rdma", help="use rdma connection", default=False)
parser.add_argument("--readers", help="number of reader/enqueue threads per worker", type=int, default=10)
parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
args = parser.parse_args()
print("args:", args)
print("{0} ===== Start".format(datetime.now().isoformat()))
cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, args.num_ps, args.tensorboard,
TFCluster.InputMode.TENSORFLOW, driver_ps_nodes=args.driver_ps_nodes)
cluster.shutdown()
print("{0} ===== Stop".format(datetime.now().isoformat()))