How to use the `tensorflowonspark.dfutil.loadTFRecords` function in TensorFlowOnSpark

To help you get started, we’ve selected a few TensorFlowOnSpark examples that reflect how this function is commonly used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: GitHub — yahoo/TensorFlowOnSpark, test/test_dfutil.py (view on GitHub)
def test_dfutils(self):
    """Round-trip a one-row DataFrame through TFRecords and compare values.

    Builds a DataFrame with string, int, int-array, float, float-array and
    binary columns. Writes it with ``dfutil.saveAsTFRecords`` and reads it
    back with ``dfutil.loadTFRecords``, declaring column ``f`` as a binary
    feature. Then checks that every loaded value (and array length) matches
    what was saved.
    """
    # create a DataFrame of a single row consisting of standard types (str, int, int_array, float, float_array, binary)
    row1 = ('text string', 1, [2, 3, 4, 5], -1.1, [-2.2, -3.3, -4.4, -5.5], bytearray(b'\xff\xfe\xfd\xfc'))
    rdd = self.sc.parallelize([row1])
    df1 = self.spark.createDataFrame(rdd, ['a', 'b', 'c', 'd', 'e', 'f'])
    print("schema: {}".format(df1.schema))

    # save the DataFrame as TFRecords
    dfutil.saveAsTFRecords(df1, self.tfrecord_dir)
    self.assertTrue(os.path.isdir(self.tfrecord_dir))

    # reload the DataFrame from exported TFRecords; 'f' is listed in
    # binary_features, presumably so it is not decoded as text -- confirm in dfutil
    df2 = dfutil.loadTFRecords(self.sc, self.tfrecord_dir, binary_features=['f'])
    row2 = df2.take(1)[0]

    print("row_saved: {}".format(row1))
    print("row_loaded: {}".format(row2))

    # confirm loaded values match original/saved values
    self.assertEqual(row1[0], row2['a'])
    self.assertEqual(row1[1], row2['b'])
    self.assertEqual(row1[2], row2['c'])
    # floats are compared to 6 decimal places; presumably the TFRecord
    # round trip does not preserve full double precision -- confirm
    self.assertAlmostEqual(row1[3], row2['d'], places=6)

    # Check lengths before the element-wise comparisons: zip/index loops alone
    # would miss extra loaded elements (or raise IndexError if too few).
    self.assertEqual(len(row1[4]), len(row2['e']))
    for expected, actual in zip(row1[4], row2['e']):
        self.assertAlmostEqual(expected, actual, places=6)

    print("type(f): {}".format(type(row2['f'])))
    self.assertEqual(len(row1[5]), len(row2['f']))
    for expected, actual in zip(row1[5], row2['f']):
        self.assertEqual(expected, actual)