139 lines
6.9 KiB
Scala
139 lines
6.9 KiB
Scala
/*
|
|
* DISCLAIMER
|
|
*
|
|
* Copyright 2016 ArangoDB GmbH, Cologne, Germany
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
* Copyright holder is ArangoDB GmbH, Cologne, Germany
|
|
*
|
|
* author Mark - mark at arangodb.com
|
|
*/
|
|
|
|
package cn.ac.iie
|
|
|
|
import java.io.FileInputStream
|
|
import java.security.KeyStore
|
|
|
|
import cn.ac.iie.spark.rdd.{ArangoOptions, ReadOptions, WriteOptions}
|
|
import com.arangodb.{ArangoDB, ArangoDBException, Protocol}
|
|
import com.arangodb.entity.LoadBalancingStrategy
|
|
import com.arangodb.velocypack.module.jdk8.VPackJdk8Module
|
|
import com.arangodb.velocypack.module.scala.VPackScalaModule
|
|
import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory}
|
|
import org.apache.spark.SparkConf
|
|
|
|
import scala.util.Try
|
|
|
|
package object spark {
|
|
|
|
val PropertyHosts = "arangodb.hosts"
|
|
val PropertyUser = "arangodb.user"
|
|
val PropertyPassword = "arangodb.password"
|
|
val PropertyUseSsl = "arangodb.useSsl"
|
|
val PropertySslKeyStoreFile = "arangodb.ssl.keyStoreFile"
|
|
val PropertySslPassPhrase = "arangodb.ssl.passPhrase"
|
|
val PropertySslProtocol = "arangodb.ssl.protocol"
|
|
val PropertyProtocol = "arangodb.protocol"
|
|
val PropertyMaxConnections = "arangodb.maxConnections"
|
|
val PropertyAcquireHostList = "arangodb.acquireHostList"
|
|
val PropertyAcquireHostListInterval = "arangodb.acquireHostListInterval"
|
|
val PropertyLoadBalancingStrategy = "arangodb.loadBalancingStrategy"
|
|
|
|
private[spark] def createReadOptions(options: ReadOptions, sc: SparkConf): ReadOptions = {
|
|
options.copy(
|
|
hosts = options.hosts.orElse(some(sc.get(PropertyHosts, null))),
|
|
user = options.user.orElse(some(sc.get(PropertyUser, null))),
|
|
password = options.password.orElse(some(sc.get(PropertyPassword, null))),
|
|
useSsl = options.useSsl.orElse(some(Try(sc.get(PropertyUseSsl, null).toBoolean).getOrElse(false))),
|
|
sslKeyStoreFile = options.sslKeyStoreFile.orElse(some(sc.get(PropertySslKeyStoreFile, null))),
|
|
sslPassPhrase = options.sslPassPhrase.orElse(some(sc.get(PropertySslPassPhrase, null))),
|
|
sslProtocol = options.sslProtocol.orElse(some(sc.get(PropertySslProtocol, null))),
|
|
protocol = options.protocol.orElse(some(Protocol.valueOf(sc.get(PropertyProtocol, "VST")))),
|
|
maxConnections = options.maxConnections.orElse(some(Try(sc.get(PropertyMaxConnections, null).toInt).getOrElse(1))),
|
|
acquireHostList = options.acquireHostList.orElse(some(Try(sc.get(PropertyAcquireHostList, null).toBoolean).getOrElse(false))),
|
|
acquireHostListInterval = options.acquireHostListInterval.orElse(some(Try(sc.get(PropertyAcquireHostListInterval, null).toInt).getOrElse(60000))),
|
|
loadBalancingStrategy = options.loadBalancingStrategy.orElse(some(LoadBalancingStrategy.valueOf(sc.get(PropertyLoadBalancingStrategy, "NONE")))))
|
|
}
|
|
|
|
private[spark] def createWriteOptions(options: WriteOptions, sc: SparkConf): WriteOptions = {
|
|
options.copy(
|
|
hosts = options.hosts.orElse(some(sc.get(PropertyHosts, null))),
|
|
user = options.user.orElse(some(sc.get(PropertyUser, null))),
|
|
password = options.password.orElse(some(sc.get(PropertyPassword, null))),
|
|
useSsl = options.useSsl.orElse(some(Try(sc.get(PropertyUseSsl, null).toBoolean).getOrElse(false))),
|
|
sslKeyStoreFile = options.sslKeyStoreFile.orElse(some(sc.get(PropertySslKeyStoreFile, null))),
|
|
sslPassPhrase = options.sslPassPhrase.orElse(some(sc.get(PropertySslPassPhrase, null))),
|
|
sslProtocol = options.sslProtocol.orElse(some(sc.get(PropertySslProtocol, null))),
|
|
protocol = options.protocol.orElse(some(Protocol.valueOf(sc.get(PropertyProtocol, "VST")))),
|
|
maxConnections = options.maxConnections.orElse(some(Try(sc.get(PropertyMaxConnections, null).toInt).getOrElse(1))),
|
|
acquireHostList = options.acquireHostList.orElse(some(Try(sc.get(PropertyAcquireHostList, null).toBoolean).getOrElse(false))),
|
|
acquireHostListInterval = options.acquireHostListInterval.orElse(some(Try(sc.get(PropertyAcquireHostListInterval, null).toInt).getOrElse(60000))),
|
|
loadBalancingStrategy = options.loadBalancingStrategy.orElse(some(LoadBalancingStrategy.valueOf(sc.get(PropertyLoadBalancingStrategy, "NONE")))))
|
|
}
|
|
|
|
private[spark] def createArangoBuilder(options: ArangoOptions): ArangoDB.Builder = {
|
|
val builder = new ArangoDB.Builder()
|
|
builder.registerModules(new VPackJdk8Module, new VPackScalaModule)
|
|
options.hosts.foreach { hosts(_).foreach(host => builder.host(host._1, host._2)) }
|
|
options.user.foreach { builder.user(_) }
|
|
options.password.foreach { builder.password(_) }
|
|
options.useSsl.foreach { builder.useSsl(_) }
|
|
if (options.sslKeyStoreFile.isDefined && options.sslPassPhrase.isDefined) {
|
|
builder.sslContext(createSslContext(options.sslKeyStoreFile.get, options.sslPassPhrase.get, options.sslProtocol.getOrElse("TLS")))
|
|
}
|
|
options.protocol.foreach { builder.useProtocol(_) }
|
|
options.maxConnections.foreach { builder.maxConnections(_) }
|
|
options.acquireHostList.foreach { builder.acquireHostList(_) }
|
|
options.acquireHostListInterval.foreach { builder.acquireHostListInterval(_) }
|
|
options.loadBalancingStrategy.foreach { builder.loadBalancingStrategy(_) }
|
|
builder
|
|
}
|
|
|
|
private def createSslContext(keyStoreFile: String, passPhrase: String, protocol: String): SSLContext = {
|
|
val ks = KeyStore.getInstance(KeyStore.getDefaultType());
|
|
val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
|
|
ks.load(new FileInputStream(keyStoreFile), passPhrase.toCharArray());
|
|
kmf.init(ks, passPhrase.toCharArray());
|
|
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
|
tmf.init(ks);
|
|
val sc = SSLContext.getInstance(protocol);
|
|
sc.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
|
|
sc
|
|
}
|
|
|
|
private def some(value: String): Option[String] =
|
|
if (value != null) Some(value) else None
|
|
|
|
private def some(value: Int): Option[Int] =
|
|
Some(value)
|
|
|
|
private def some(value: Boolean): Option[Boolean] =
|
|
Some(value)
|
|
|
|
private def some(value: Protocol): Option[Protocol] =
|
|
Some(value)
|
|
|
|
private def some(value: LoadBalancingStrategy): Option[LoadBalancingStrategy] =
|
|
Some(value)
|
|
|
|
private def hosts(hosts: String): List[(String, Int)] =
|
|
hosts.split(",").map({ x =>
|
|
val s = x.split(":")
|
|
if (s.length != 2 || !s(1).matches("[0-9]+"))
|
|
throw new ArangoDBException(s"Could not load property-value arangodb.hosts=${s}. Expected format ip:port,ip:port,...");
|
|
else
|
|
(s(0), s(1).toInt)
|
|
}).toList
|
|
|
|
} |