Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala')
 mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala | 179
 1 file changed, 75 insertions(+), 104 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index 65328df17b..b7072728d4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -17,19 +17,19 @@
package org.apache.spark.ml.clustering
-import scala.collection.mutable
-
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._
class PowerIterationClusteringSuite extends SparkFunSuite
with MLlibTestSparkContext with DefaultReadWriteTest {
+ import testImplicits._
+
@transient var data: Dataset[_] = _
final val r1 = 1.0
final val n1 = 10
@@ -48,10 +48,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite
assert(pic.getK === 2)
assert(pic.getMaxIter === 20)
assert(pic.getInitMode === "random")
- assert(pic.getPredictionCol === "prediction")
- assert(pic.getIdCol === "id")
- assert(pic.getNeighborsCol === "neighbors")
- assert(pic.getSimilaritiesCol === "similarities")
+ assert(pic.getSrcCol === "src")
+ assert(pic.getDstCol === "dst")
+ assert(!pic.isDefined(pic.weightCol))
}
test("parameter validation") {
@@ -62,125 +61,102 @@ class PowerIterationClusteringSuite extends SparkFunSuite
new PowerIterationClustering().setInitMode("no_such_a_mode")
}
intercept[IllegalArgumentException] {
- new PowerIterationClustering().setIdCol("")
+ new PowerIterationClustering().setSrcCol("")
}
intercept[IllegalArgumentException] {
- new PowerIterationClustering().setNeighborsCol("")
- }
- intercept[IllegalArgumentException] {
- new PowerIterationClustering().setSimilaritiesCol("")
+ new PowerIterationClustering().setDstCol("")
}
}
test("power iteration clustering") {
val n = n1 + n2
- val model = new PowerIterationClustering()
+ val assignments = new PowerIterationClustering()
.setK(2)
.setMaxIter(40)
- val result = model.transform(data)
-
- val predictions = Array.fill(2)(mutable.Set.empty[Long])
- result.select("id", "prediction").collect().foreach {
- case Row(id: Long, cluster: Integer) => predictions(cluster) += id
- }
- assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet))
-
- val result2 = new PowerIterationClustering()
+ .setWeightCol("weight")
+ .assignClusters(data)
+ val localAssignments = assignments
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+ val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++
+ (n1 until n).map(x => (x, 0)).toSet
+ assert(localAssignments === expectedResult)
+
+ val assignments2 = new PowerIterationClustering()
.setK(2)
.setMaxIter(10)
.setInitMode("degree")
- .transform(data)
- val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
- result2.select("id", "prediction").collect().foreach {
- case Row(id: Long, cluster: Integer) => predictions2(cluster) += id
- }
- assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet))
+ .setWeightCol("weight")
+ .assignClusters(data)
+ val localAssignments2 = assignments2
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+ assert(localAssignments2 === expectedResult)
}
test("supported input types") {
- val model = new PowerIterationClustering()
+ val pic = new PowerIterationClustering()
.setK(2)
.setMaxIter(1)
+ .setWeightCol("weight")
- def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = {
+ def runTest(srcType: DataType, dstType: DataType, weightType: DataType): Unit = {
val typedData = data.select(
- col("id").cast(idType).alias("id"),
- col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"),
- col("similarities").cast(ArrayType(similarityType, containsNull = false))
- .alias("similarities")
+ col("src").cast(srcType).alias("src"),
+ col("dst").cast(dstType).alias("dst"),
+ col("weight").cast(weightType).alias("weight")
)
- model.transform(typedData).collect()
- }
-
- for (idType <- Seq(IntegerType, LongType)) {
- runTest(idType, LongType, DoubleType)
- }
- for (neighborType <- Seq(IntegerType, LongType)) {
- runTest(LongType, neighborType, DoubleType)
- }
- for (similarityType <- Seq(FloatType, DoubleType)) {
- runTest(LongType, LongType, similarityType)
+ pic.assignClusters(typedData).collect()
}
- }
- test("invalid input: wrong types") {
- val model = new PowerIterationClustering()
- .setK(2)
- .setMaxIter(1)
- intercept[IllegalArgumentException] {
- val typedData = data.select(
- col("id").cast(DoubleType).alias("id"),
- col("neighbors"),
- col("similarities")
- )
- model.transform(typedData)
+ for (srcType <- Seq(IntegerType, LongType)) {
+ runTest(srcType, LongType, DoubleType)
}
- intercept[IllegalArgumentException] {
- val typedData = data.select(
- col("id"),
- col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"),
- col("similarities")
- )
- model.transform(typedData)
+ for (dstType <- Seq(IntegerType, LongType)) {
+ runTest(LongType, dstType, DoubleType)
}
- intercept[IllegalArgumentException] {
- val typedData = data.select(
- col("id"),
- col("neighbors"),
- col("neighbors").alias("similarities")
- )
- model.transform(typedData)
+ for (weightType <- Seq(FloatType, DoubleType)) {
+ runTest(LongType, LongType, weightType)
}
}
test("invalid input: negative similarity") {
- val model = new PowerIterationClustering()
+ val pic = new PowerIterationClustering()
.setMaxIter(1)
+ .setWeightCol("weight")
val badData = spark.createDataFrame(Seq(
- (0, Array(1), Array(-1.0)),
- (1, Array(0), Array(-1.0))
- )).toDF("id", "neighbors", "similarities")
+ (0, 1, -1.0),
+ (1, 0, -1.0)
+ )).toDF("src", "dst", "weight")
val msg = intercept[SparkException] {
- model.transform(badData)
+ pic.assignClusters(badData)
}.getCause.getMessage
assert(msg.contains("Similarity must be nonnegative"))
}
- test("invalid input: mismatched lengths for neighbor and similarity arrays") {
- val model = new PowerIterationClustering()
- .setMaxIter(1)
- val badData = spark.createDataFrame(Seq(
- (0, Array(1), Array(0.5)),
- (1, Array(0, 2), Array(0.5)),
- (2, Array(1), Array(0.5))
- )).toDF("id", "neighbors", "similarities")
- val msg = intercept[SparkException] {
- model.transform(badData)
- }.getCause.getMessage
- assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " +
- "the neighbor similarity list."))
- assert(msg.contains(s"Row for ID ${model.getIdCol}=1"))
+ test("test default weight") {
+ val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst)
+
+ val assignments = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(40)
+ .assignClusters(dataWithoutWeight)
+ val localAssignments = assignments
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+
+ val dataWithWeightOne = dataWithoutWeight.withColumn("weight", lit(1.0))
+
+ val assignments2 = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(40)
+ .assignClusters(dataWithWeightOne)
+ val localAssignments2 = assignments2
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+
+ assert(localAssignments === localAssignments2)
}
test("read/write") {
@@ -188,10 +164,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite
.setK(4)
.setMaxIter(100)
.setInitMode("degree")
- .setIdCol("test_id")
- .setNeighborsCol("myNeighborsCol")
- .setSimilaritiesCol("mySimilaritiesCol")
- .setPredictionCol("test_prediction")
+ .setSrcCol("src1")
+ .setDstCol("dst1")
+ .setWeightCol("weight")
testDefaultReadWrite(t)
}
}
@@ -222,17 +197,13 @@ object PowerIterationClusteringSuite {
val n = n1 + n2
val points = genCircle(r1, n1) ++ genCircle(r2, n2)
- val rows = for (i <- 1 until n) yield {
- val neighbors = for (j <- 0 until i) yield {
- j.toLong
+ val rows = (for (i <- 1 until n) yield {
+ for (j <- 0 until i) yield {
+ (i.toLong, j.toLong, sim(points(i), points(j)))
}
- val similarities = for (j <- 0 until i) yield {
- sim(points(i), points(j))
- }
- (i.toLong, neighbors.toArray, similarities.toArray)
- }
+ }).flatMap(_.iterator)
- spark.createDataFrame(rows).toDF("id", "neighbors", "similarities")
+ spark.createDataFrame(rows).toDF("src", "dst", "weight")
}
}
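
For context, here is a minimal, self-contained sketch of the API this suite now exercises: the tests build a flat edge list with src, dst, and weight columns (replacing the old id/neighbors/similarities adjacency-list layout) and call assignClusters, which returns a DataFrame of (id, cluster) assignments. The toy graph below is illustrative only and is not part of the suite; column and parameter names follow the test code above, and the concrete cluster labels produced are not guaranteed.

// Sketch only: assumes a local Spark environment in which
// org.apache.spark.ml.clustering.PowerIterationClustering exposes
// setSrcCol/setDstCol/setWeightCol and assignClusters, as in the suite above.
import org.apache.spark.ml.clustering.PowerIterationClustering
import org.apache.spark.sql.SparkSession

object PICExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("PICExample")
      .getOrCreate()

    // Edge list of an undirected similarity graph with two obvious groups,
    // {0, 1, 2} and {3, 4}. Weights must be nonnegative.
    val similarities = spark.createDataFrame(Seq(
      (0L, 1L, 1.0),
      (0L, 2L, 1.0),
      (1L, 2L, 1.0),
      (3L, 4L, 1.0)
    )).toDF("src", "dst", "weight")

    val pic = new PowerIterationClustering()
      .setK(2)
      .setMaxIter(20)
      .setWeightCol("weight") // omit this and every edge defaults to weight 1.0

    // assignClusters returns a DataFrame with "id" and "cluster" columns.
    val assignments = pic.assignClusters(similarities)
    assignments.orderBy("id").show()

    spark.stop()
  }
}

The edge-list input is the same shape the suite's genTestData helper now produces (one (src, dst, weight) row per pair of circle points), and the unset-weight case corresponds to the "test default weight" test above.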