spark_with_tfrecords

spark-tensorflow-connector

spark-tensorflow-connector is TensorFlow's solution for generating TFRecord files with Spark: it exposes TFRecords as a Spark DataFrame source and sink.

Installation and Configuration

  1. git clone https://github.com/tensorflow/ecosystem
  2. Install and configure Maven

    2.1 Download the Maven archive from http://maven.apache.org/download.cgi

    2.2 Extract it to a directory of your choice

    2.3 Configure ~/.bash_profile

    export M2_HOME=/Users/xxx/apache-maven-3.3.3

    export PATH=$PATH:$M2_HOME/bin

    2.4 Run source ~/.bash_profile to apply the configuration

    2.5 Run mvn -v to verify that Maven is configured correctly

    (If this fails, set JAVA_HOME.)

  3. Open the xx/ecosystem/hadoop project in IntelliJ and run maven clean & install

    If the following error appears:

    [ERROR] Failed to execute goal org.apache.maven.plugins:maven-javadoc-plugin:2.9:jar (attach-javadocs) on project template.querylist: MavenReportException: Error while creating archive: Unable to find javadoc command: The environment variable JAVA_HOME is not correctly set. -> [Help 1]

    you can fix it by adding the following to the POM:

    <properties>
      <javadocExecutable>${java.home}/../bin/javadoc</javadocExecutable>
    </properties>
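
    Alternatively, you can skip javadoc generation entirely by passing Maven's standard skip property:

    mvn clean install -Dmaven.javadoc.skip=true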
  4. Open the xx/ecosystem/spark/spark-tensorflow-connector project in IntelliJ and run maven clean & install; this produces the final spark-tensorflow-connector_2.11-1.10.0.jar (a command-line equivalent follows below)
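
If you prefer the command line to IntelliJ, steps 3 and 4 are equivalent to running mvn clean install inside each module directory:

cd xx/ecosystem/hadoop && mvn clean install
cd xx/ecosystem/spark/spark-tensorflow-connector && mvn clean install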

Usage

The official Scala example:

package com.xxx

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

object TFRecordsExample {
  def main(args: Array[String]): Unit = {
    // Quiet down Spark and Jetty logging
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()

    val path = "file/test-output.tfrecord"
    val testRows: Array[Row] = Array(
      new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
      new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))

    val schema = StructType(List(
      StructField("id", IntegerType),
      StructField("IntegerCol", IntegerType),
      StructField("LongCol", LongType),
      StructField("FloatCol", FloatType),
      StructField("DoubleCol", DoubleType),
      StructField("VectorCol", ArrayType(DoubleType, true)),
      StructField("StringCol", StringType)))

    val rdd = spark.sparkContext.parallelize(testRows)

    // Save DataFrame as TFRecords
    val df: DataFrame = spark.createDataFrame(rdd, schema)
    df.write.format("tfrecords").option("recordType", "SequenceExample").save(path)

    // Read TFRecords into a DataFrame.
    // The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
    val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").load(path)
    importedDf1.show()

    // Read TFRecords into a DataFrame using the custom schema
    val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
    importedDf2.show()
  }
}
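
The connector also supports the flat Example record type, which is its default. A minimal sketch reusing df from the example above (the output path here is illustrative):

// Write flat Example records instead of SequenceExamples
df.write.format("tfrecords").option("recordType", "Example").save("file/test-output-example.tfrecord")

// Read them back; the schema is again inferred when none is given
val exampleDf = spark.read.format("tfrecords").option("recordType", "Example").load("file/test-output-example.tfrecord")
exampleDf.show()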

The POM file:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xxx.ai</groupId>
    <artifactId>tfrecordsProj</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>spark-tensorflow-connector_2.11</artifactId>
            <version>1.10.0</version>
        </dependency>
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>1.2.16</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>2.1.0</version>
        </dependency>
        <!-- All Scala artifacts use the 2.11 binary line,
             matching spark-sql_2.11 and the connector -->
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.11.8</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-compiler</artifactId>
            <version>2.11.8</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-reflect</artifactId>
            <version>2.11.8</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang.modules</groupId>
            <artifactId>scala-xml_2.11</artifactId>
            <version>1.1.0</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang.modules</groupId>
            <artifactId>scala-parser-combinators_2.11</artifactId>
            <version>1.1.0</version>
        </dependency>
    </dependencies>
</project>
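
This POM declares dependencies only, since IntelliJ's Scala plugin handles compilation. To build the Scala example with Maven itself you would also need a Scala compiler plugin in a build section, roughly as follows (the scala-maven-plugin version is an assumption; any recent release should work):

<build>
    <plugins>
        <plugin>
            <groupId>net.alchim31.maven</groupId>
            <artifactId>scala-maven-plugin</artifactId>
            <!-- assumed version -->
            <version>3.2.2</version>
            <executions>
                <execution>
                    <goals>
                        <goal>compile</goal>
                        <goal>testCompile</goal>
                    </goals>
                </execution>
            </executions>
        </plugin>
    </plugins>
</build>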

The official Python example:

from pyspark.sql import SparkSession
from pyspark.sql.types import *


spark = SparkSession.builder.master("local[4]").appName("tfrecords_examples").getOrCreate()
path = '/Users/linglingsu/Documents/intellij_workspace/tfrecordsProj/file/test-output.tfrecord'
# Custom schema matching the records written by the Scala example
fields = [StructField("id", IntegerType()), StructField("IntegerCol", IntegerType()),
          StructField("LongCol", LongType()), StructField("FloatCol", FloatType()),
          StructField("DoubleCol", DoubleType()), StructField("VectorCol", ArrayType(DoubleType(), True)),
          StructField("StringCol", StringType())]
schema = StructType(fields)
# Read the TFRecords using the custom schema
df = spark.read.format("tfrecords").option("recordType", "SequenceExample").schema(schema).load(path)
df_pandas = df.toPandas()
print(df_pandas.head())

Run it with the following command:


spark-submit --jars xxx/ecosystem/spark/spark-tensorflow-connector/target/spark-tensorflow-connector_2.11-1.10.0.jar test.py
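
The Scala example can likewise be tried interactively by starting spark-shell with the same jar on the classpath:

spark-shell --jars xxx/ecosystem/spark/spark-tensorflow-connector/target/spark-tensorflow-connector_2.11-1.10.0.jar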