def test_basic_tf(self):
  """Single-node TF graph (w/ args) running independently on multiple executors."""
  def _map_fun(args, ctx):
    import tensorflow as tf
    x = tf.constant(args['x'])
    y = tf.constant(args['y'])
    total = tf.math.add(x, y)
    assert total.numpy() == 3

  args = {'x': 1, 'y': 2}
  cluster = TFCluster.run(self.sc, _map_fun, tf_args=args, num_executors=self.num_workers, num_ps=0)
  cluster.shutdown()
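# The test above assumes a unittest fixture that provides self.sc (a SparkContext) and
# self.num_workers. A minimal sketch of such a fixture follows; the class name, worker count,
# and local-mode configuration are illustrative assumptions, not the project's actual harness
# (in a real test module these imports would sit at the top of the file).
import unittest

from pyspark import SparkConf, SparkContext
from tensorflowonspark import TFCluster


class TFClusterTestSketch(unittest.TestCase):
  """Hypothetical fixture: local-mode Spark with one core per simulated worker."""
  num_workers = 2

  @classmethod
  def setUpClass(cls):
    conf = SparkConf().setMaster('local[{}]'.format(cls.num_workers)).setAppName('tfcluster_test_sketch')
    cls.sc = SparkContext(conf=conf)

  @classmethod
  def tearDownClass(cls):
    cls.sc.stop()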
# Driver-side launch logic: pick an input mode, start the cluster, and feed data from Spark
# when using InputMode.SPARK. (Hedged sketches of the assumed argument parsing and of a
# main_fun follow this block.)
if args.input_mode == 'tf':
  # for TENSORFLOW mode, each node will load/train/infer the entire dataset in memory, per the original example
  cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW, log_dir=args.model_dir, master_node='master')
  cluster.shutdown()
else:  # 'spark'
  # for SPARK mode, just use CSV format as an example
  images = sc.textFile(args.images).map(lambda ln: [float(x) for x in ln.split(',')])
  labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
  dataRDD = images.zip(labels)
  if args.mode == 'train':
    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='master')
    cluster.train(dataRDD, args.epochs)
    cluster.shutdown()
  else:
    # Note: using "parallel" inferencing, not "cluster" inferencing;
    # each node loads the model and runs independently of the others
    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, 0, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir)
    resultRDD = cluster.inference(dataRDD)
    resultRDD.saveAsTextFile(args.output)
    cluster.shutdown()
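# The driver code above assumes a SparkContext `sc`, a `main_fun(args, ctx)`, and a parsed
# argument namespace. A hedged sketch of that setup follows; the flag names mirror the
# attributes used above, but defaults and help text are illustrative assumptions.
import argparse

from pyspark import SparkConf, SparkContext
from tensorflowonspark import TFCluster

sc = SparkContext(conf=SparkConf().setAppName('tfos_example_sketch'))
num_executors = int(sc.getConf().get('spark.executor.instances', '1'))

parser = argparse.ArgumentParser()
parser.add_argument('--cluster_size', help='number of nodes in the cluster', type=int, default=num_executors)
parser.add_argument('--num_ps', help='number of parameter servers', type=int, default=0)
parser.add_argument('--tensorboard', help='launch a TensorBoard process', action='store_true')
parser.add_argument('--input_mode', help="'tf' or 'spark'", default='spark')
parser.add_argument('--mode', help="'train' or 'inference'", default='train')
parser.add_argument('--images', help='path to CSV image data')
parser.add_argument('--labels', help='path to CSV label data')
parser.add_argument('--epochs', help='number of training epochs', type=int, default=1)
parser.add_argument('--model_dir', help='directory for the saved model and logs')
parser.add_argument('--output', help='output path for inference results')
args = parser.parse_args()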
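# A minimal sketch of a SPARK-mode main_fun, i.e. the function each executor runs. The
# data-feed pattern (ctx.get_data_feed() pulling the records that cluster.train() /
# cluster.inference() push from dataRDD) follows TensorFlowOnSpark convention, but the model,
# feature/label widths, batch size, and step counts below are illustrative assumptions; see
# the project's mnist examples for the authoritative version (parameter-server handling and
# distribution strategies are omitted for brevity).
def main_fun(args, ctx):
  import numpy as np
  import tensorflow as tf

  tf_feed = ctx.get_data_feed(args.mode == 'train')  # this executor's slice of dataRDD

  def rdd_generator():
    # Convert the (image, label) CSV records fed by Spark into numpy arrays for tf.data.
    while not tf_feed.should_stop():
      batch = tf_feed.next_batch(1)
      if not batch:
        return
      image, label = batch[0]
      yield np.array(image, dtype=np.float32), np.array(label, dtype=np.float32)

  if args.mode == 'train':
    ds = tf.data.Dataset.from_generator(
        rdd_generator,
        output_signature=(tf.TensorSpec(shape=(784,), dtype=tf.float32),   # assumed feature width
                          tf.TensorSpec(shape=(10,), dtype=tf.float32)))   # assumed one-hot labels
    ds = ds.batch(32)
    model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation='softmax')])  # placeholder model
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    model.fit(ds, epochs=args.epochs, steps_per_epoch=100)   # assumed steps_per_epoch
    if ctx.job_name == 'master' and ctx.task_index == 0:      # matches master_node='master' above
      model.save(args.model_dir)
  else:
    # "parallel" inference: each executor loads the model and returns one result per fed record
    model = tf.keras.models.load_model(args.model_dir)
    while not tf_feed.should_stop():
      batch = tf_feed.next_batch(32)
      if not batch:
        break
      images = np.array([rec[0] for rec in batch], dtype=np.float32)
      preds = model.predict(images)
      tf_feed.batch_results(preds.argmax(axis=1).tolist())   # one prediction per input record

  # terminating the feed tells Spark to skip any unprocessed partitions
  tf_feed.terminate()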