From c211d99c2e680cbb8fa4182b6786ad30f1f2bdaf Mon Sep 17 00:00:00 2001
From: wanglihui <949764788@qq.com>
Date: Fri, 23 Oct 2020 10:02:28 +0800
Subject: [PATCH] Custom ArangoRDD
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
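
Adds a custom ArangoRDD and an ArangoSpark helper object (adapted from the
ArangoDB Spark connector) so that ArangoDB collections can be loaded into
Spark RDDs and RDDs/Datasets/DataFrames can be written back. Partitions are
built by paging through the collection with AQL LIMIT clauses, DataFrame rows
are converted to VelocyPack via VPackUtils, and connection settings are read
from the SparkConf (arangodb.hosts/user/password, see SparkSessionUtil).

A minimal usage sketch (collection "IP" and database "ip-learning-test-0" as
in RDDTest):

    val options = ReadOptions("ip-learning-test-0")
    val rdd = ArangoSpark.load[BaseDocument](sparkContext, "IP", options)
    ArangoSpark.save(rdd, "IP", WriteOptions("ip-learning-test-0", WriteOptions.UPDATE))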
---
ip-learning-spark/pom.xml | 17 ++
.../src/main/resources/application.properties | 2 +-
.../scala/cn/ac/iie/spark/ArangoSpark.scala | 131 ++++++++++++++
.../partition/QueryArangoPartition.scala | 7 +
.../cn/ac/iie/spark/rdd/ArangoOptions.scala | 34 ++++
.../scala/cn/ac/iie/spark/rdd/ArangoRdd.scala | 81 +++++++++
.../cn/ac/iie/spark/rdd/ReadOptions.scala | 93 ++++++++++
.../cn/ac/iie/spark/rdd/WriteOptions.scala | 119 +++++++++++++
.../main/scala/cn/ac/iie/spark/spark.scala | 139 +++++++++++++++
.../cn/ac/iie/spark/vpack/VPackUtils.scala | 160 ++++++++++++++++++
.../cn/ac/iie/utils/SparkSessionUtil.scala | 3 +
.../test/scala/cn/ac/iie/spark/RDDTest.scala | 40 +++++
12 files changed, 825 insertions(+), 1 deletion(-)
create mode 100644 ip-learning-spark/src/main/scala/cn/ac/iie/spark/ArangoSpark.scala
create mode 100644 ip-learning-spark/src/main/scala/cn/ac/iie/spark/partition/QueryArangoPartition.scala
create mode 100644 ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoOptions.scala
create mode 100644 ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoRdd.scala
create mode 100644 ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ReadOptions.scala
create mode 100644 ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/WriteOptions.scala
create mode 100644 ip-learning-spark/src/main/scala/cn/ac/iie/spark/spark.scala
create mode 100644 ip-learning-spark/src/main/scala/cn/ac/iie/spark/vpack/VPackUtils.scala
create mode 100644 ip-learning-spark/src/test/scala/cn/ac/iie/spark/RDDTest.scala
diff --git a/ip-learning-spark/pom.xml b/ip-learning-spark/pom.xml
index 204fa68..7ea3c38 100644
--- a/ip-learning-spark/pom.xml
+++ b/ip-learning-spark/pom.xml
@@ -63,6 +63,18 @@
             <version>6.6.3</version>
+        <dependency>
+            <groupId>com.arangodb</groupId>
+            <artifactId>velocypack-module-jdk8</artifactId>
+            <version>1.1.0</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.arangodb</groupId>
+            <artifactId>velocypack-module-scala_2.11</artifactId>
+            <version>1.2.0</version>
+        </dependency>
+
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>
@@ -75,6 +87,11 @@
             <version>3.2.0</version>
+        <dependency>
+            <groupId>org.scala-lang.modules</groupId>
+            <artifactId>scala-xml_2.11</artifactId>
+            <version>1.0.4</version>
+        </dependency>
             <groupId>org.scala-tools</groupId>
diff --git a/ip-learning-spark/src/main/resources/application.properties b/ip-learning-spark/src/main/resources/application.properties
index 0010b23..c2e81ea 100644
--- a/ip-learning-spark/src/main/resources/application.properties
+++ b/ip-learning-spark/src/main/resources/application.properties
@@ -21,7 +21,7 @@ arangoDB.port=8529
arangoDB.user=upsert
arangoDB.password=ceiec2018
#arangoDB.DB.name=insert_iplearn_index
-arangoDB.DB.name=ip-learning-test-0
+arangoDB.DB.name=iplearn_media_domain
arangoDB.ttl=3600
thread.pool.number=5
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/ArangoSpark.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/ArangoSpark.scala
new file mode 100644
index 0000000..b492f9a
--- /dev/null
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/ArangoSpark.scala
@@ -0,0 +1,131 @@
+/*
+ * DISCLAIMER
+ *
+ * Copyright 2016 ArangoDB GmbH, Cologne, Germany
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Copyright holder is ArangoDB GmbH, Cologne, Germany
+ *
+ * author Mark - mark at arangodb.com
+ */
+
+package cn.ac.iie.spark
+
+import cn.ac.iie.spark.rdd.{ArangoRdd, ReadOptions, WriteOptions}
+import cn.ac.iie.spark.vpack.VPackUtils
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+
+import scala.collection.JavaConverters.seqAsJavaListConverter
+import scala.reflect.ClassTag
+
+object ArangoSpark {
+
+ /**
+ * Save data from rdd into ArangoDB
+ *
+ * @param rdd the rdd with the data to save
+ * @param collection the collection to save in
+ */
+ def save[T](rdd: RDD[T], collection: String): Unit =
+ save(rdd, collection, WriteOptions())
+
+ /**
+ * Save data from rdd into ArangoDB
+ *
+ * @param rdd the rdd with the data to save
+ * @param collection the collection to save in
+ * @param options additional write options
+ */
+ def save[T](rdd: RDD[T], collection: String, options: WriteOptions): Unit =
+ saveRDD(rdd, collection, options, (x: Iterator[T]) => x)
+
+ /**
+ * Save data from dataset into ArangoDB
+ *
+ * @param dataset the dataset with data to save
+ * @param collection the collection to save in
+ */
+ def save[T](dataset: Dataset[T], collection: String): Unit =
+ saveRDD(dataset.rdd, collection, WriteOptions(), (x: Iterator[T]) => x)
+
+ /**
+ * Save data from dataset into ArangoDB
+ *
+ * @param dataset the dataset with data to save
+ * @param collection the collection to save in
+ * @param options additional write options
+ */
+ def save[T](dataset: Dataset[T], collection: String, options: WriteOptions): Unit =
+ saveRDD(dataset.rdd, collection, options, (x: Iterator[T]) => x)
+
+ /**
+ * Save data from dataframe into ArangoDB
+ *
+ * @param dataframe the dataframe with data to save
+ * @param collection the collection to save in
+   */
+ def saveDF(dataframe: DataFrame, collection: String): Unit =
+ saveRDD[Row](dataframe.rdd, collection, WriteOptions(), (x: Iterator[Row]) => x.map { y => VPackUtils.rowToVPack(y) })
+
+ /**
+ * Save data from dataframe into ArangoDB
+ *
+ * @param dataframe the dataframe with data to save
+ * @param collection the collection to save in
+ * @param options additional write options
+ */
+ def saveDF(dataframe: DataFrame, collection: String, options: WriteOptions): Unit =
+ saveRDD[Row](dataframe.rdd, collection, options, (x: Iterator[Row]) => x.map { y => VPackUtils.rowToVPack(y) })
+
+ private def saveRDD[T](rdd: RDD[T], collection: String, options: WriteOptions, map: Iterator[T] => Iterator[Any]): Unit = {
+ val writeOptions = createWriteOptions(options, rdd.sparkContext.getConf)
+ rdd.foreachPartition { p =>
+ if (p.nonEmpty) {
+ val arangoDB = createArangoBuilder(writeOptions).build()
+ val col = arangoDB.db(writeOptions.database).collection(collection)
+ val docs = map(p).toList.asJava
+ writeOptions.method match {
+ case WriteOptions.INSERT => col.insertDocuments(docs)
+ case WriteOptions.UPDATE => col.updateDocuments(docs)
+ case WriteOptions.REPLACE => col.replaceDocuments(docs)
+ }
+
+ arangoDB.shutdown()
+ }
+ }
+ }
+
+ /**
+ * Load data from ArangoDB into rdd
+ *
+ * @param sparkContext the sparkContext containing the ArangoDB configuration
+ * @param collection the collection to load data from
+ */
+ def load[T: ClassTag](sparkContext: SparkContext, collection: String): ArangoRdd[T] =
+ load(sparkContext, collection, ReadOptions())
+
+ /**
+ * Load data from ArangoDB into rdd
+ *
+ * @param sparkContext the sparkContext containing the ArangoDB configuration
+ * @param collection the collection to load data from
+   * @param options additional read options
+ */
+ def load[T: ClassTag](sparkContext: SparkContext, collection: String, options: ReadOptions): ArangoRdd[T] =
+ new ArangoRdd[T](sparkContext, createReadOptions(options, sparkContext.getConf).copy(collection = collection))
+
+}
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/partition/QueryArangoPartition.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/partition/QueryArangoPartition.scala
new file mode 100644
index 0000000..7cf134f
--- /dev/null
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/partition/QueryArangoPartition.scala
@@ -0,0 +1,7 @@
+package cn.ac.iie.spark.partition
+
+import org.apache.spark.Partition
+
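+/** A partition covering `separate` documents of the collection, starting at `offset`. */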
+class QueryArangoPartition(idx: Int, val offset: Long, val separate: Long) extends Partition {
+ override def index: Int = idx
+}
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoOptions.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoOptions.scala
new file mode 100644
index 0000000..885d077
--- /dev/null
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoOptions.scala
@@ -0,0 +1,34 @@
+package cn.ac.iie.spark.rdd
+
+import com.arangodb.Protocol
+import com.arangodb.entity.LoadBalancingStrategy
+
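+/** Common ArangoDB connection settings shared by [[ReadOptions]] and [[WriteOptions]]. */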
+trait ArangoOptions {
+
+ def database: String = "_system"
+
+ def hosts: Option[String] = None
+
+ def user: Option[String] = None
+
+ def password: Option[String] = None
+
+ def useSsl: Option[Boolean] = None
+
+ def sslKeyStoreFile: Option[String] = None
+
+ def sslPassPhrase: Option[String] = None
+
+ def sslProtocol: Option[String] = None
+
+ def protocol: Option[Protocol] = None
+
+ def maxConnections: Option[Int] = None
+
+ def acquireHostList: Option[Boolean] = None
+
+ def acquireHostListInterval: Option[Int] = None
+
+ def loadBalancingStrategy: Option[LoadBalancingStrategy] = None
+
+}
\ No newline at end of file
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoRdd.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoRdd.scala
new file mode 100644
index 0000000..adf3e1b
--- /dev/null
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoRdd.scala
@@ -0,0 +1,81 @@
+package cn.ac.iie.spark.rdd
+
+import scala.collection.JavaConverters.asScalaIteratorConverter
+import cn.ac.iie.config.ApplicationConfig
+import cn.ac.iie.spark
+import cn.ac.iie.spark.partition.QueryArangoPartition
+import com.arangodb.ArangoCursor
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.slf4j.LoggerFactory
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
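+/**
+ * RDD backed by an ArangoDB collection. Each partition reads one slice of the
+ * collection with an AQL query of the form
+ * `FOR doc IN <collection> LIMIT offset, count RETURN doc`.
+ */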
+class ArangoRdd[T: ClassTag](@transient override val sparkContext: SparkContext,
+ val options: ReadOptions
+ ) extends RDD[T](sparkContext, Nil) {
+
+  private val LOG = LoggerFactory.getLogger(classOf[ArangoRdd[_]])
+
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+
+ createCursor(split.asInstanceOf[QueryArangoPartition]).asScala
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ val partitions = ArrayBuffer[Partition]()
+ val total = getCountTotal
+ for (i <- 0 until ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS) {
+ val partition = getPartition(i, total)
+ partitions += partition
+ }
+ partitions.toArray
+ }
+
+  private def createCursor(split: QueryArangoPartition)(implicit clazz: ClassTag[T]): ArangoCursor[T] = {
+    val arangoDB = spark.createArangoBuilder(options).build()
+    try {
+      val offset = split.offset
+      val separate = split.separate
+      val collection = options.collection
+      val aql = s"FOR doc IN $collection LIMIT $offset, $separate RETURN doc"
+      LOG.info(aql)
+      // the connection stays open: the cursor is consumed lazily by compute()
+      arangoDB.db(options.database).query(aql, clazz.runtimeClass.asInstanceOf[Class[T]])
+    } catch {
+      case e: Exception =>
+        LOG.error("Failed to create ArangoDB cursor", e)
+        arangoDB.shutdown()
+        throw e
+    }
+  }
+
+ override def repartition(numPartitions: Int)(implicit ord: Ordering[T]): RDD[T] = super.repartition(numPartitions)
+
+  private def getPartition(idx: Int, countTotal: Long): QueryArangoPartition = {
+    // split the collection evenly across the number of partitions created in getPartitions
+    val sepNum = countTotal / ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS + 1
+    val offsetNum = idx * sepNum
+    new QueryArangoPartition(idx, offsetNum, sepNum)
+  }
+
+ override def count(): Long = getCountTotal
+
+ private def getCountTotal: Long = {
+ val arangoDB = spark.createArangoBuilder(options).build()
+ var cnt = 0L
+ val sql = "RETURN LENGTH(" + options.collection + ")"
+ LOG.info(sql)
+ try {
+ val longs = arangoDB.db(options.database).query(sql, classOf[Long])
+ while (longs.hasNext) cnt = longs.next
+ } catch {
+      case e: Exception => LOG.error(s"Failed to execute count query: $sql", e)
+ }finally {
+ arangoDB.shutdown()
+ }
+ cnt
+ }
+}
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ReadOptions.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ReadOptions.scala
new file mode 100644
index 0000000..875ea75
--- /dev/null
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ReadOptions.scala
@@ -0,0 +1,93 @@
+/*
+ * DISCLAIMER
+ *
+ * Copyright 2016 ArangoDB GmbH, Cologne, Germany
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Copyright holder is ArangoDB GmbH, Cologne, Germany
+ *
+ * author Mark - mark at arangodb.com
+ */
+
+package cn.ac.iie.spark.rdd
+
+import cn.ac.iie.spark.partition.QueryArangoPartition
+import com.arangodb.Protocol
+import com.arangodb.entity.LoadBalancingStrategy
+
+case class ReadOptions(override val database: String = "_system",
+ val collection: String = null,
+ partitioner: QueryArangoPartition = new QueryArangoPartition(0,0,0),
+ override val hosts: Option[String] = None,
+ override val user: Option[String] = None,
+ override val password: Option[String] = None,
+ override val useSsl: Option[Boolean] = None,
+ override val sslKeyStoreFile: Option[String] = None,
+ override val sslPassPhrase: Option[String] = None,
+ override val sslProtocol: Option[String] = None,
+ override val protocol: Option[Protocol] = None,
+ override val maxConnections: Option[Int] = None,
+ override val acquireHostList: Option[Boolean] = None,
+ override val acquireHostListInterval: Option[Int] = None,
+ override val loadBalancingStrategy: Option[LoadBalancingStrategy] = None) extends ArangoOptions {
+
+ def this() = this(database = "_system")
+
+ def database(database: String): ReadOptions = copy(database = database)
+
+ def collection(collection: String): ReadOptions = copy(collection = collection)
+
+ def hosts(hosts: String): ReadOptions = copy(hosts = Some(hosts))
+
+ def user(user: String): ReadOptions = copy(user = Some(user))
+
+ def password(password: String): ReadOptions = copy(password = Some(password))
+
+ def useSsl(useSsl: Boolean): ReadOptions = copy(useSsl = Some(useSsl))
+
+ def sslKeyStoreFile(sslKeyStoreFile: String): ReadOptions = copy(sslKeyStoreFile = Some(sslKeyStoreFile))
+
+ def sslPassPhrase(sslPassPhrase: String): ReadOptions = copy(sslPassPhrase = Some(sslPassPhrase))
+
+ def sslProtocol(sslProtocol: String): ReadOptions = copy(sslProtocol = Some(sslProtocol))
+
+ def protocol(protocol: Protocol): ReadOptions = copy(protocol = Some(protocol))
+
+ def maxConnections(maxConnections: Int): ReadOptions = copy(maxConnections = Some(maxConnections))
+
+ def acquireHostList(acquireHostList: Boolean): ReadOptions = copy(acquireHostList = Some(acquireHostList))
+
+ def acquireHostListInterval(acquireHostListInterval: Int): ReadOptions = copy(acquireHostListInterval = Some(acquireHostListInterval))
+
+ def loadBalancingStrategy(loadBalancingStrategy: LoadBalancingStrategy): ReadOptions = copy(loadBalancingStrategy = Some(loadBalancingStrategy))
+
+ def copy(database: String = database,
+ collection: String = collection,
+ partitioner: QueryArangoPartition = partitioner,
+ hosts: Option[String] = hosts,
+ user: Option[String] = user,
+ password: Option[String] = password,
+ useSsl: Option[Boolean] = useSsl,
+ sslKeyStoreFile: Option[String] = sslKeyStoreFile,
+ sslPassPhrase: Option[String] = sslPassPhrase,
+ sslProtocol: Option[String] = sslProtocol,
+ protocol: Option[Protocol] = protocol,
+ maxConnections: Option[Int] = maxConnections,
+ acquireHostList: Option[Boolean] = acquireHostList,
+ acquireHostListInterval: Option[Int] = acquireHostListInterval,
+ loadBalancingStrategy: Option[LoadBalancingStrategy] = loadBalancingStrategy): ReadOptions = {
+ ReadOptions(database, collection, partitioner, hosts, user, password, useSsl, sslKeyStoreFile, sslPassPhrase, sslProtocol, protocol, maxConnections, acquireHostList, acquireHostListInterval, loadBalancingStrategy)
+ }
+
+}
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/WriteOptions.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/WriteOptions.scala
new file mode 100644
index 0000000..46f3c80
--- /dev/null
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/WriteOptions.scala
@@ -0,0 +1,119 @@
+/*
+ * DISCLAIMER
+ *
+ * Copyright 2016 ArangoDB GmbH, Cologne, Germany
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Copyright holder is ArangoDB GmbH, Cologne, Germany
+ *
+ * author Mark - mark at arangodb.com
+ */
+
+package cn.ac.iie.spark.rdd
+
+import com.arangodb.Protocol
+import com.arangodb.entity.LoadBalancingStrategy
+
+case class WriteOptions(override val database: String = "_system",
+ val method: WriteOptions.Method = WriteOptions.INSERT,
+ override val hosts: Option[String] = None,
+ override val user: Option[String] = None,
+ override val password: Option[String] = None,
+ override val useSsl: Option[Boolean] = None,
+ override val sslKeyStoreFile: Option[String] = None,
+ override val sslPassPhrase: Option[String] = None,
+ override val sslProtocol: Option[String] = None,
+ override val protocol: Option[Protocol] = None,
+ override val maxConnections: Option[Int] = None,
+ override val acquireHostList: Option[Boolean] = None,
+ override val acquireHostListInterval: Option[Int] = None,
+ override val loadBalancingStrategy: Option[LoadBalancingStrategy] = None) extends ArangoOptions {
+ import WriteOptions._
+
+ def this() = this(database = "_system")
+
+ def database(database: String): WriteOptions = copy(database = database)
+
+ def method(method: Method): WriteOptions = copy(method = method)
+
+ def hosts(hosts: String): WriteOptions = copy(hosts = Some(hosts))
+
+ def user(user: String): WriteOptions = copy(user = Some(user))
+
+ def password(password: String): WriteOptions = copy(password = Some(password))
+
+ def useSsl(useSsl: Boolean): WriteOptions = copy(useSsl = Some(useSsl))
+
+ def sslKeyStoreFile(sslKeyStoreFile: String): WriteOptions = copy(sslKeyStoreFile = Some(sslKeyStoreFile))
+
+ def sslPassPhrase(sslPassPhrase: String): WriteOptions = copy(sslPassPhrase = Some(sslPassPhrase))
+
+ def sslProtocol(sslProtocol: String): WriteOptions = copy(sslProtocol = Some(sslProtocol))
+
+ def protocol(protocol: Protocol): WriteOptions = copy(protocol = Some(protocol))
+
+ def maxConnections(maxConnections: Int): WriteOptions = copy(maxConnections = Some(maxConnections))
+
+ def acquireHostList(acquireHostList: Boolean): WriteOptions = copy(acquireHostList = Some(acquireHostList))
+
+ def acquireHostListInterval(acquireHostListInterval: Int): WriteOptions = copy(acquireHostListInterval = Some(acquireHostListInterval))
+
+ def loadBalancingStrategy(loadBalancingStrategy: LoadBalancingStrategy): WriteOptions = copy(loadBalancingStrategy = Some(loadBalancingStrategy))
+
+ def copy(database: String = database,
+ method: Method = method,
+ hosts: Option[String] = hosts,
+ user: Option[String] = user,
+ password: Option[String] = password,
+ useSsl: Option[Boolean] = useSsl,
+ sslKeyStoreFile: Option[String] = sslKeyStoreFile,
+ sslPassPhrase: Option[String] = sslPassPhrase,
+ sslProtocol: Option[String] = sslProtocol,
+ protocol: Option[Protocol] = protocol,
+ maxConnections: Option[Int] = maxConnections,
+ acquireHostList: Option[Boolean] = acquireHostList,
+ acquireHostListInterval: Option[Int] = acquireHostListInterval,
+ loadBalancingStrategy: Option[LoadBalancingStrategy] = loadBalancingStrategy): WriteOptions = {
+ WriteOptions(database, method, hosts, user, password, useSsl, sslKeyStoreFile, sslPassPhrase, sslProtocol, protocol, maxConnections, acquireHostList, acquireHostListInterval, loadBalancingStrategy)
+ }
+
+}
+
+object WriteOptions {
+
+ /**
+ * method to save documents to arangodb
+ */
+ sealed trait Method
+
+ /**
+ * save documents by inserting
+ * @see [[com.arangodb.ArangoCollection#insertDocuments(java.util.Collection)]]
+ */
+ case object INSERT extends Method
+
+ /**
+ * save documents by updating
+ * @see [[com.arangodb.ArangoCollection#updateDocuments(java.util.Collection)]]
+ */
+ case object UPDATE extends Method
+
+ /**
+ * save documents by replacing
+ * @see [[com.arangodb.ArangoCollection#replaceDocuments(java.util.Collection)]]
+ */
+ case object REPLACE extends Method
+
+}
\ No newline at end of file
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/spark.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/spark.scala
new file mode 100644
index 0000000..7dcfbc9
--- /dev/null
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/spark.scala
@@ -0,0 +1,139 @@
+/*
+ * DISCLAIMER
+ *
+ * Copyright 2016 ArangoDB GmbH, Cologne, Germany
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Copyright holder is ArangoDB GmbH, Cologne, Germany
+ *
+ * author Mark - mark at arangodb.com
+ */
+
+package cn.ac.iie
+
+import java.io.FileInputStream
+import java.security.KeyStore
+
+import cn.ac.iie.spark.rdd.{ArangoOptions, ReadOptions, WriteOptions}
+import com.arangodb.{ArangoDB, ArangoDBException, Protocol}
+import com.arangodb.entity.LoadBalancingStrategy
+import com.arangodb.velocypack.module.jdk8.VPackJdk8Module
+import com.arangodb.velocypack.module.scala.VPackScalaModule
+import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory}
+import org.apache.spark.SparkConf
+
+import scala.util.Try
+
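+/**
+ * Helpers for resolving ArangoDB connection settings: values missing from the
+ * Read/WriteOptions are taken from the SparkConf properties defined below, and
+ * [[createArangoBuilder]] turns the resolved options into an ArangoDB.Builder.
+ */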
+package object spark {
+
+ val PropertyHosts = "arangodb.hosts"
+ val PropertyUser = "arangodb.user"
+ val PropertyPassword = "arangodb.password"
+ val PropertyUseSsl = "arangodb.useSsl"
+ val PropertySslKeyStoreFile = "arangodb.ssl.keyStoreFile"
+ val PropertySslPassPhrase = "arangodb.ssl.passPhrase"
+ val PropertySslProtocol = "arangodb.ssl.protocol"
+ val PropertyProtocol = "arangodb.protocol"
+ val PropertyMaxConnections = "arangodb.maxConnections"
+ val PropertyAcquireHostList = "arangodb.acquireHostList"
+ val PropertyAcquireHostListInterval = "arangodb.acquireHostListInterval"
+ val PropertyLoadBalancingStrategy = "arangodb.loadBalancingStrategy"
+
+ private[spark] def createReadOptions(options: ReadOptions, sc: SparkConf): ReadOptions = {
+ options.copy(
+ hosts = options.hosts.orElse(some(sc.get(PropertyHosts, null))),
+ user = options.user.orElse(some(sc.get(PropertyUser, null))),
+ password = options.password.orElse(some(sc.get(PropertyPassword, null))),
+ useSsl = options.useSsl.orElse(some(Try(sc.get(PropertyUseSsl, null).toBoolean).getOrElse(false))),
+ sslKeyStoreFile = options.sslKeyStoreFile.orElse(some(sc.get(PropertySslKeyStoreFile, null))),
+ sslPassPhrase = options.sslPassPhrase.orElse(some(sc.get(PropertySslPassPhrase, null))),
+ sslProtocol = options.sslProtocol.orElse(some(sc.get(PropertySslProtocol, null))),
+ protocol = options.protocol.orElse(some(Protocol.valueOf(sc.get(PropertyProtocol, "VST")))),
+ maxConnections = options.maxConnections.orElse(some(Try(sc.get(PropertyMaxConnections, null).toInt).getOrElse(1))),
+ acquireHostList = options.acquireHostList.orElse(some(Try(sc.get(PropertyAcquireHostList, null).toBoolean).getOrElse(false))),
+ acquireHostListInterval = options.acquireHostListInterval.orElse(some(Try(sc.get(PropertyAcquireHostListInterval, null).toInt).getOrElse(60000))),
+ loadBalancingStrategy = options.loadBalancingStrategy.orElse(some(LoadBalancingStrategy.valueOf(sc.get(PropertyLoadBalancingStrategy, "NONE")))))
+ }
+
+ private[spark] def createWriteOptions(options: WriteOptions, sc: SparkConf): WriteOptions = {
+ options.copy(
+ hosts = options.hosts.orElse(some(sc.get(PropertyHosts, null))),
+ user = options.user.orElse(some(sc.get(PropertyUser, null))),
+ password = options.password.orElse(some(sc.get(PropertyPassword, null))),
+ useSsl = options.useSsl.orElse(some(Try(sc.get(PropertyUseSsl, null).toBoolean).getOrElse(false))),
+ sslKeyStoreFile = options.sslKeyStoreFile.orElse(some(sc.get(PropertySslKeyStoreFile, null))),
+ sslPassPhrase = options.sslPassPhrase.orElse(some(sc.get(PropertySslPassPhrase, null))),
+ sslProtocol = options.sslProtocol.orElse(some(sc.get(PropertySslProtocol, null))),
+ protocol = options.protocol.orElse(some(Protocol.valueOf(sc.get(PropertyProtocol, "VST")))),
+ maxConnections = options.maxConnections.orElse(some(Try(sc.get(PropertyMaxConnections, null).toInt).getOrElse(1))),
+ acquireHostList = options.acquireHostList.orElse(some(Try(sc.get(PropertyAcquireHostList, null).toBoolean).getOrElse(false))),
+ acquireHostListInterval = options.acquireHostListInterval.orElse(some(Try(sc.get(PropertyAcquireHostListInterval, null).toInt).getOrElse(60000))),
+ loadBalancingStrategy = options.loadBalancingStrategy.orElse(some(LoadBalancingStrategy.valueOf(sc.get(PropertyLoadBalancingStrategy, "NONE")))))
+ }
+
+ private[spark] def createArangoBuilder(options: ArangoOptions): ArangoDB.Builder = {
+ val builder = new ArangoDB.Builder()
+ builder.registerModules(new VPackJdk8Module, new VPackScalaModule)
+ options.hosts.foreach { hosts(_).foreach(host => builder.host(host._1, host._2)) }
+ options.user.foreach { builder.user(_) }
+ options.password.foreach { builder.password(_) }
+ options.useSsl.foreach { builder.useSsl(_) }
+ if (options.sslKeyStoreFile.isDefined && options.sslPassPhrase.isDefined) {
+ builder.sslContext(createSslContext(options.sslKeyStoreFile.get, options.sslPassPhrase.get, options.sslProtocol.getOrElse("TLS")))
+ }
+ options.protocol.foreach { builder.useProtocol(_) }
+ options.maxConnections.foreach { builder.maxConnections(_) }
+ options.acquireHostList.foreach { builder.acquireHostList(_) }
+ options.acquireHostListInterval.foreach { builder.acquireHostListInterval(_) }
+ options.loadBalancingStrategy.foreach { builder.loadBalancingStrategy(_) }
+ builder
+ }
+
+ private def createSslContext(keyStoreFile: String, passPhrase: String, protocol: String): SSLContext = {
+ val ks = KeyStore.getInstance(KeyStore.getDefaultType());
+ val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+ ks.load(new FileInputStream(keyStoreFile), passPhrase.toCharArray());
+ kmf.init(ks, passPhrase.toCharArray());
+ val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+ tmf.init(ks);
+ val sc = SSLContext.getInstance(protocol);
+ sc.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
+ sc
+ }
+
+ private def some(value: String): Option[String] =
+ if (value != null) Some(value) else None
+
+ private def some(value: Int): Option[Int] =
+ Some(value)
+
+ private def some(value: Boolean): Option[Boolean] =
+ Some(value)
+
+ private def some(value: Protocol): Option[Protocol] =
+ Some(value)
+
+ private def some(value: LoadBalancingStrategy): Option[LoadBalancingStrategy] =
+ Some(value)
+
+ private def hosts(hosts: String): List[(String, Int)] =
+ hosts.split(",").map({ x =>
+ val s = x.split(":")
+ if (s.length != 2 || !s(1).matches("[0-9]+"))
+      throw new ArangoDBException(s"Could not load property-value arangodb.hosts=$hosts. Expected format ip:port,ip:port,...")
+ else
+ (s(0), s(1).toInt)
+ }).toList
+
+}
\ No newline at end of file
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/vpack/VPackUtils.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/vpack/VPackUtils.scala
new file mode 100644
index 0000000..e66b64f
--- /dev/null
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/vpack/VPackUtils.scala
@@ -0,0 +1,160 @@
+/*
+ * DISCLAIMER
+ *
+ * Copyright 2016 ArangoDB GmbH, Cologne, Germany
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Copyright holder is ArangoDB GmbH, Cologne, Germany
+ *
+ * author Mark - mark at arangodb.com
+ */
+
+package cn.ac.iie.spark.vpack
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{
+ ArrayType,
+ BooleanType,
+ DataType,
+ DateType,
+ DecimalType,
+ DoubleType,
+ FloatType,
+ IntegerType,
+ LongType,
+ MapType,
+ NullType,
+ ShortType,
+ StringType,
+ StructField,
+ StructType,
+ TimestampType
+}
+import com.arangodb.velocypack.VPackBuilder
+import com.arangodb.velocypack.VPackSlice
+import com.arangodb.velocypack.ValueType
+
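+/**
+ * Converts Spark SQL [[Row]]s into VelocyPack slices so that DataFrames can be
+ * written to ArangoDB via [[cn.ac.iie.spark.ArangoSpark.saveDF]].
+ */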
+private[spark] object VPackUtils {
+
+ def rowToVPack(row: Row): VPackSlice = {
+ val builder = new VPackBuilder()
+ if (row == null) {
+ builder.add(ValueType.NULL)
+ } else {
+ builder.add(ValueType.OBJECT)
+ row.schema.fields.zipWithIndex.foreach { addField(_, row, builder) }
+ builder.close()
+ }
+ builder.slice()
+ }
+
+ private def addField(field: (StructField, Int), row: Row, builder: VPackBuilder): Unit = {
+ val name = field._1.name
+ val index = field._2
+ if (row.isNullAt(index)) {
+ builder.add(name, ValueType.NULL)
+ } else {
+ field._1.dataType match {
+ case BooleanType => builder.add(name, java.lang.Boolean.valueOf(row.getBoolean(index)))
+ case DoubleType => builder.add(name, java.lang.Double.valueOf(row.getDouble(index)))
+ case FloatType => builder.add(name, java.lang.Float.valueOf(row.getFloat(index)))
+ case LongType => builder.add(name, java.lang.Long.valueOf(row.getLong(index)))
+ case IntegerType => builder.add(name, java.lang.Integer.valueOf(row.getInt(index)))
+ case ShortType => builder.add(name, java.lang.Short.valueOf(row.getShort(index)))
+ case StringType => builder.add(name, java.lang.String.valueOf(row.getString(index)));
+ case DateType => builder.add(name, row.getDate(index))
+ case TimestampType => builder.add(name, row.getTimestamp(index))
+ case t: DecimalType => builder.add(name, row.getDecimal(index))
+ case t: MapType => {
+ builder.add(name, ValueType.OBJECT)
+ row.getMap[String, Any](index).foreach { case (name, value) => addValue(name, value, builder) }
+ builder.close()
+ }
+ case t: ArrayType => {
+ builder.add(name, ValueType.ARRAY)
+ addValues(row, index, builder, t.elementType)
+ builder.close()
+ }
+ case NullType => builder.add(name, ValueType.NULL)
+ case struct: StructType => builder.add(name, rowToVPack(row.getStruct(index)))
+ case _ => // TODO
+ }
+ }
+ }
+
+ private def addValues(row: Row, index: Int, builder: VPackBuilder, itemType: DataType): Unit = {
+ itemType match {
+ case BooleanType =>
+ row.getSeq[Boolean](index).foreach { value =>
+ addValue(null, value, builder)
+ }
+ case DoubleType =>
+ row.getSeq[Double](index).foreach { value =>
+ addValue(null, value, builder)
+ }
+ case FloatType =>
+ row.getSeq[Float](index).foreach { value =>
+ addValue(null, value, builder)
+ }
+ case LongType =>
+ row.getSeq[Long](index).foreach { value =>
+ addValue(null, value, builder)
+ }
+ case IntegerType =>
+ row.getSeq[Int](index).foreach { value =>
+ addValue(null, value, builder)
+ }
+ case ShortType =>
+ row.getSeq[Short](index).foreach { value =>
+ addValue(null, value, builder)
+ }
+ case StringType =>
+ row.getSeq[String](index).foreach { value =>
+ addValue(null, value, builder)
+ }
+ case DateType =>
+ row.getSeq[java.sql.Date](index).foreach { value =>
+ addValue(null, value, builder)
+ }
+ case TimestampType =>
+ row.getSeq[java.sql.Timestamp](index).foreach { value =>
+ addValue(null, value, builder)
+ }
+ case s: StructType => {
+ row.getSeq[Row](index).foreach { value =>
+ builder.add(null, rowToVPack(value))
+ }
+ }
+ case t: MapType => // TODO
+ case t: ArrayType => // TODO
+ case _ => // TODO
+ }
+ }
+
+ private def addValue(name: String, value: Any, builder: VPackBuilder): Unit = {
+ value match {
+ case value: Boolean => builder.add(name, java.lang.Boolean.valueOf(value))
+ case value: Double => builder.add(name, java.lang.Double.valueOf(value))
+ case value: Float => builder.add(name, java.lang.Float.valueOf(value))
+ case value: Long => builder.add(name, java.lang.Long.valueOf(value))
+ case value: Int => builder.add(name, java.lang.Integer.valueOf(value))
+ case value: Short => builder.add(name, java.lang.Short.valueOf(value))
+ case value: String => builder.add(name, java.lang.String.valueOf(value))
+ case value: java.sql.Date => builder.add(name, value)
+ case value: java.sql.Timestamp => builder.add(name, value)
+ case _ => // TODO
+ }
+ }
+
+}
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala
index 12cfc86..8f7661a 100644
--- a/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala
@@ -17,6 +17,9 @@ object SparkSessionUtil {
.config("spark.network.timeout", ApplicationConfig.SPARK_NETWORK_TIMEOUT)
.config("spark.sql.shuffle.partitions", ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)
.config("spark.executor.memory", ApplicationConfig.SPARK_EXECUTOR_MEMORY)
+ .config("arangodb.hosts", s"${ApplicationConfig.ARANGODB_HOST}:${ApplicationConfig.ARANGODB_PORT}")
+ .config("arangodb.user", ApplicationConfig.ARANGODB_USER)
+ .config("arangodb.password", ApplicationConfig.ARANGODB_PASSWORD)
.master(ApplicationConfig.MASTER)
.getOrCreate()
LOG.warn("sparkession获取成功!!!")
diff --git a/ip-learning-spark/src/test/scala/cn/ac/iie/spark/RDDTest.scala b/ip-learning-spark/src/test/scala/cn/ac/iie/spark/RDDTest.scala
new file mode 100644
index 0000000..08cdcf3
--- /dev/null
+++ b/ip-learning-spark/src/test/scala/cn/ac/iie/spark/RDDTest.scala
@@ -0,0 +1,40 @@
+package cn.ac.iie.spark
+
+import cn.ac.iie.spark.rdd.ReadOptions
+import cn.ac.iie.utils.SparkSessionUtil
+import com.arangodb.entity.BaseDocument
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+object RDDTest {
+ def main(args: Array[String]): Unit = {
+
+ val sparkContext = SparkSessionUtil.spark.sparkContext
+
+ println(sparkContext.getConf.get("arangodb.hosts"))
+
+ // val options = ReadOptions("iplearn_media_domain").copy(collection = "R_LOCATE_FQDN2IP")
+ val options = ReadOptions("ip-learning-test-0")
+
+ val ipOptions = options.copy(collection = "IP")
+
+ val rdd = ArangoSpark.load[BaseDocument](sparkContext,"IP",options)
+
+ println(rdd.count())
+ println(rdd.getNumPartitions)
+
+ val value: RDD[BaseDocument] = rdd.filter(doc => doc.getAttribute("CLIENT_SESSION_COUNT").asInstanceOf[Long] > 100).map(doc => {
+ doc.addAttribute("abc", 1)
+ doc
+ })
+ value.persist(StorageLevel.MEMORY_AND_DISK)
+
+ value.foreach(row => println(row.toString))
+ println(value.count())
+
+ SparkSessionUtil.spark.close()
+ System.exit(0)
+
+ }
+
+}