summaryrefslogtreecommitdiff
path: root/modules/aggs-matrix-stats/src
diff options
context:
space:
mode:
authorMartijn van Groningen <martijn.v.groningen@gmail.com>2017-05-09 15:04:28 +0200
committerMartijn van Groningen <martijn.v.groningen@gmail.com>2017-05-10 11:06:18 +0200
commit51c74ce5476ed4d8b0f2ed3469fee955ff1e5fd1 (patch)
tree002eeca1732646f83045580dba0449e31ad78e9a /modules/aggs-matrix-stats/src
parentb24326271e6778d5d595005e7e1e4258e7e7ee24 (diff)
Added unit tests for InternalMatrixStats.
Also moved InternalAggregationTestCase to test-framework module in order to make use of it from other modules than core. Relates to #22278
Diffstat (limited to 'modules/aggs-matrix-stats/src')
-rw-r--r--modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java19
-rw-r--r--modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java15
-rw-r--r--modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/RunningStats.java25
-rw-r--r--modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/BaseMatrixStatsTestCase.java124
-rw-r--r--modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java103
-rw-r--r--modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MultiPassStats.java155
6 files changed, 315 insertions, 126 deletions
diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java
index 0914ea2910..5b7d2cf288 100644
--- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java
+++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java
@@ -28,6 +28,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import static java.util.Collections.emptyMap;
@@ -41,7 +42,7 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
private final MatrixStatsResults results;
/** per shard ctor */
- protected InternalMatrixStats(String name, long count, RunningStats multiFieldStatsResults, MatrixStatsResults results,
+ InternalMatrixStats(String name, long count, RunningStats multiFieldStatsResults, MatrixStatsResults results,
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
super(name, pipelineAggregators, metaData);
assert count >= 0;
@@ -138,6 +139,10 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
return results.getCorrelation(fieldX, fieldY);
}
+ MatrixStatsResults getResults() {
+ return results;
+ }
+
static class Fields {
public static final String FIELDS = "fields";
public static final String NAME = "name";
@@ -238,4 +243,16 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
return new InternalMatrixStats(name, results.getDocCount(), runningStats, results, pipelineAggregators(), getMetaData());
}
+
+ @Override
+ protected int doHashCode() {
+ return Objects.hash(stats, results);
+ }
+
+ @Override
+ protected boolean doEquals(Object obj) {
+ InternalMatrixStats other = (InternalMatrixStats) obj;
+ return Objects.equals(this.stats, other.stats) &&
+ Objects.equals(this.results, other.results);
+ }
}
diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java
index f82c6df73b..4da8b7ca61 100644
--- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java
+++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java
@@ -27,6 +27,7 @@ import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+import java.util.Objects;
/**
* Descriptive stats gathered per shard. Coordinating node computes final pearson product coefficient
@@ -228,4 +229,18 @@ class MatrixStatsResults implements Writeable {
correlation.put(rowName, corRow);
}
}
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ MatrixStatsResults that = (MatrixStatsResults) o;
+ return Objects.equals(results, that.results) &&
+ Objects.equals(correlation, that.correlation);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(results, correlation);
+ }
}
diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/RunningStats.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/RunningStats.java
index 81d0d0a494..1be3279e8e 100644
--- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/RunningStats.java
+++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/RunningStats.java
@@ -28,6 +28,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
+import java.util.Objects;
/**
* Descriptive stats gathered per shard. Coordinating node computes final correlation and covariance stats
@@ -53,11 +54,11 @@ public class RunningStats implements Writeable, Cloneable {
/** covariance values */
protected HashMap<String, HashMap<String, Double>> covariances;
- public RunningStats() {
+ RunningStats() {
init();
}
- public RunningStats(final String[] fieldNames, final double[] fieldVals) {
+ RunningStats(final String[] fieldNames, final double[] fieldVals) {
if (fieldVals != null && fieldVals.length > 0) {
init();
this.add(fieldNames, fieldVals);
@@ -309,4 +310,24 @@ public class RunningStats implements Writeable, Cloneable {
throw new ElasticsearchException("Error trying to create a copy of RunningStats");
}
}
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ RunningStats that = (RunningStats) o;
+ return docCount == that.docCount &&
+ Objects.equals(fieldSum, that.fieldSum) &&
+ Objects.equals(counts, that.counts) &&
+ Objects.equals(means, that.means) &&
+ Objects.equals(variances, that.variances) &&
+ Objects.equals(skewness, that.skewness) &&
+ Objects.equals(kurtosis, that.kurtosis) &&
+ Objects.equals(covariances, that.covariances);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(docCount, fieldSum, counts, means, variances, skewness, kurtosis, covariances);
+ }
}
diff --git a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/BaseMatrixStatsTestCase.java b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/BaseMatrixStatsTestCase.java
index 81c9d51463..091235bf82 100644
--- a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/BaseMatrixStatsTestCase.java
+++ b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/BaseMatrixStatsTestCase.java
@@ -22,15 +22,12 @@ import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import java.util.ArrayList;
-import java.util.HashMap;
-
-import static org.hamcrest.Matchers.equalTo;
public abstract class BaseMatrixStatsTestCase extends ESTestCase {
protected final int numObs = atLeast(10000);
protected final ArrayList<Double> fieldA = new ArrayList<>(numObs);
protected final ArrayList<Double> fieldB = new ArrayList<>(numObs);
- protected final MultiPassStats actualStats = new MultiPassStats();
+ protected final MultiPassStats actualStats = new MultiPassStats(fieldAKey, fieldBKey);
protected static final String fieldAKey = "fieldA";
protected static final String fieldBKey = "fieldB";
@@ -47,123 +44,4 @@ public abstract class BaseMatrixStatsTestCase extends ESTestCase {
actualStats.computeStats(fieldA, fieldB);
}
- static class MultiPassStats {
- long count;
- HashMap<String, Double> means = new HashMap<>();
- HashMap<String, Double> variances = new HashMap<>();
- HashMap<String, Double> skewness = new HashMap<>();
- HashMap<String, Double> kurtosis = new HashMap<>();
- HashMap<String, HashMap<String, Double>> covariances = new HashMap<>();
- HashMap<String, HashMap<String, Double>> correlations = new HashMap<>();
-
- @SuppressWarnings("unchecked")
- void computeStats(final ArrayList<Double> fieldA, final ArrayList<Double> fieldB) {
- // set count
- count = fieldA.size();
- double meanA = 0d;
- double meanB = 0d;
-
- // compute mean
- for (int n = 0; n < count; ++n) {
- // fieldA
- meanA += fieldA.get(n);
- meanB += fieldB.get(n);
- }
- means.put(fieldAKey, meanA/count);
- means.put(fieldBKey, meanB/count);
-
- // compute variance, skewness, and kurtosis
- double dA;
- double dB;
- double skewA = 0d;
- double skewB = 0d;
- double kurtA = 0d;
- double kurtB = 0d;
- double varA = 0d;
- double varB = 0d;
- double cVar = 0d;
- for (int n = 0; n < count; ++n) {
- dA = fieldA.get(n) - means.get(fieldAKey);
- varA += dA * dA;
- skewA += dA * dA * dA;
- kurtA += dA * dA * dA * dA;
- dB = fieldB.get(n) - means.get(fieldBKey);
- varB += dB * dB;
- skewB += dB * dB * dB;
- kurtB += dB * dB * dB * dB;
- cVar += dA * dB;
- }
- variances.put(fieldAKey, varA / (count - 1));
- final double stdA = Math.sqrt(variances.get(fieldAKey));
- variances.put(fieldBKey, varB / (count - 1));
- final double stdB = Math.sqrt(variances.get(fieldBKey));
- skewness.put(fieldAKey, skewA / ((count - 1) * variances.get(fieldAKey) * stdA));
- skewness.put(fieldBKey, skewB / ((count - 1) * variances.get(fieldBKey) * stdB));
- kurtosis.put(fieldAKey, kurtA / ((count - 1) * variances.get(fieldAKey) * variances.get(fieldAKey)));
- kurtosis.put(fieldBKey, kurtB / ((count - 1) * variances.get(fieldBKey) * variances.get(fieldBKey)));
-
- // compute covariance
- final HashMap<String, Double> fieldACovar = new HashMap<>(2);
- fieldACovar.put(fieldAKey, 1d);
- cVar /= count - 1;
- fieldACovar.put(fieldBKey, cVar);
- covariances.put(fieldAKey, fieldACovar);
- final HashMap<String, Double> fieldBCovar = new HashMap<>(2);
- fieldBCovar.put(fieldAKey, cVar);
- fieldBCovar.put(fieldBKey, 1d);
- covariances.put(fieldBKey, fieldBCovar);
-
- // compute correlation
- final HashMap<String, Double> fieldACorr = new HashMap<>();
- fieldACorr.put(fieldAKey, 1d);
- double corr = covariances.get(fieldAKey).get(fieldBKey);
- corr /= stdA * stdB;
- fieldACorr.put(fieldBKey, corr);
- correlations.put(fieldAKey, fieldACorr);
- final HashMap<String, Double> fieldBCorr = new HashMap<>();
- fieldBCorr.put(fieldAKey, corr);
- fieldBCorr.put(fieldBKey, 1d);
- correlations.put(fieldBKey, fieldBCorr);
- }
-
- public void assertNearlyEqual(MatrixStatsResults stats) {
- assertThat(count, equalTo(stats.getDocCount()));
- assertThat(count, equalTo(stats.getFieldCount(fieldAKey)));
- assertThat(count, equalTo(stats.getFieldCount(fieldBKey)));
- // means
- assertTrue(nearlyEqual(means.get(fieldAKey), stats.getMean(fieldAKey), 1e-7));
- assertTrue(nearlyEqual(means.get(fieldBKey), stats.getMean(fieldBKey), 1e-7));
- // variances
- assertTrue(nearlyEqual(variances.get(fieldAKey), stats.getVariance(fieldAKey), 1e-7));
- assertTrue(nearlyEqual(variances.get(fieldBKey), stats.getVariance(fieldBKey), 1e-7));
- // skewness (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
- assertTrue(nearlyEqual(skewness.get(fieldAKey), stats.getSkewness(fieldAKey), 1e-4));
- assertTrue(nearlyEqual(skewness.get(fieldBKey), stats.getSkewness(fieldBKey), 1e-4));
- // kurtosis (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
- assertTrue(nearlyEqual(kurtosis.get(fieldAKey), stats.getKurtosis(fieldAKey), 1e-4));
- assertTrue(nearlyEqual(kurtosis.get(fieldBKey), stats.getKurtosis(fieldBKey), 1e-4));
- // covariances
- assertTrue(nearlyEqual(covariances.get(fieldAKey).get(fieldBKey), stats.getCovariance(fieldAKey, fieldBKey), 1e-7));
- assertTrue(nearlyEqual(covariances.get(fieldBKey).get(fieldAKey), stats.getCovariance(fieldBKey, fieldAKey), 1e-7));
- // correlation
- assertTrue(nearlyEqual(correlations.get(fieldAKey).get(fieldBKey), stats.getCorrelation(fieldAKey, fieldBKey), 1e-7));
- assertTrue(nearlyEqual(correlations.get(fieldBKey).get(fieldAKey), stats.getCorrelation(fieldBKey, fieldAKey), 1e-7));
- }
- }
-
- private static boolean nearlyEqual(double a, double b, double epsilon) {
- final double absA = Math.abs(a);
- final double absB = Math.abs(b);
- final double diff = Math.abs(a - b);
-
- if (a == b) { // shortcut, handles infinities
- return true;
- } else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) {
- // a or b is zero or both are extremely close to it
- // relative error is less meaningful here
- return diff < (epsilon * Double.MIN_NORMAL);
- } else { // use relative error
- return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon;
- }
- }
}
diff --git a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java
new file mode 100644
index 0000000000..277006da90
--- /dev/null
+++ b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.search.aggregations.matrix.stats;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
+import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.search.aggregations.InternalAggregation;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.test.InternalAggregationTestCase;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+public class InternalMatrixStatsTests extends InternalAggregationTestCase<InternalMatrixStats> {
+
+ @Override
+ protected InternalMatrixStats createTestInstance(String name, List<PipelineAggregator> pipelineAggregators,
+ Map<String, Object> metaData) {
+ int numFields = randomInt(128);
+ String[] fieldNames = new String[numFields];
+ double[] fieldValues = new double[numFields];
+ for (int i = 0; i < numFields; i++) {
+ fieldNames[i] = Integer.toString(i);
+ fieldValues[i] = randomDouble();
+ }
+ RunningStats runningStats = new RunningStats();
+ runningStats.add(fieldNames, fieldValues);
+ MatrixStatsResults matrixStatsResults = randomBoolean() ? new MatrixStatsResults(runningStats) : null;
+ return new InternalMatrixStats("_name", 1L, runningStats, matrixStatsResults, Collections.emptyList(), Collections.emptyMap());
+ }
+
+ @Override
+ protected Writeable.Reader<InternalMatrixStats> instanceReader() {
+ return InternalMatrixStats::new;
+ }
+
+ @Override
+ public void testReduceRandom() {
+ int numValues = 10000;
+ int numShards = randomIntBetween(1, 20);
+ int valuesPerShard = (int) Math.floor(numValues / numShards);
+
+ List<Double> aValues = new ArrayList<>();
+ List<Double> bValues = new ArrayList<>();
+
+ RunningStats runningStats = new RunningStats();
+ List<InternalAggregation> shardResults = new ArrayList<>();
+
+ int valuePerShardCounter = 0;
+ for (int i = 0; i < numValues; i++) {
+ double valueA = randomDouble();
+ aValues.add(valueA);
+ double valueB = randomDouble();
+ bValues.add(valueB);
+
+ runningStats.add(new String[]{"a", "b"}, new double[]{valueA, valueB});
+ if (++valuePerShardCounter == valuesPerShard) {
+ shardResults.add(new InternalMatrixStats("_name", 1L, runningStats, null, Collections.emptyList(), Collections.emptyMap()));
+ runningStats = new RunningStats();
+ valuePerShardCounter = 0;
+ }
+ }
+
+ if (valuePerShardCounter != 0) {
+ shardResults.add(new InternalMatrixStats("_name", 1L, runningStats, null, Collections.emptyList(), Collections.emptyMap()));
+ }
+ MultiPassStats multiPassStats = new MultiPassStats("a", "b");
+ multiPassStats.computeStats(aValues, bValues);
+
+ ScriptService mockScriptService = mockScriptService();
+ MockBigArrays bigArrays = new MockBigArrays(Settings.EMPTY, new NoneCircuitBreakerService());
+ InternalAggregation.ReduceContext context =
+ new InternalAggregation.ReduceContext(bigArrays, mockScriptService, true);
+ InternalMatrixStats reduced = (InternalMatrixStats) shardResults.get(0).reduce(shardResults, context);
+ multiPassStats.assertNearlyEqual(reduced.getResults());
+ }
+
+ @Override
+ protected void assertReduced(InternalMatrixStats reduced, List<InternalMatrixStats> inputs) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MultiPassStats.java b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MultiPassStats.java
new file mode 100644
index 0000000000..70e2172ce9
--- /dev/null
+++ b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MultiPassStats.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.search.aggregations.matrix.stats;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+class MultiPassStats {
+
+ private final String fieldAKey;
+ private final String fieldBKey;
+
+ private long count;
+ private Map<String, Double> means = new HashMap<>();
+ private Map<String, Double> variances = new HashMap<>();
+ private Map<String, Double> skewness = new HashMap<>();
+ private Map<String, Double> kurtosis = new HashMap<>();
+ private Map<String, HashMap<String, Double>> covariances = new HashMap<>();
+ private Map<String, HashMap<String, Double>> correlations = new HashMap<>();
+
+ MultiPassStats(String fieldAName, String fieldBName) {
+ this.fieldAKey = fieldAName;
+ this.fieldBKey = fieldBName;
+ }
+
+ @SuppressWarnings("unchecked")
+ void computeStats(final List<Double> fieldA, final List<Double> fieldB) {
+ // set count
+ count = fieldA.size();
+ double meanA = 0d;
+ double meanB = 0d;
+
+ // compute mean
+ for (int n = 0; n < count; ++n) {
+ // fieldA
+ meanA += fieldA.get(n);
+ meanB += fieldB.get(n);
+ }
+ means.put(fieldAKey, meanA / count);
+ means.put(fieldBKey, meanB / count);
+
+ // compute variance, skewness, and kurtosis
+ double dA;
+ double dB;
+ double skewA = 0d;
+ double skewB = 0d;
+ double kurtA = 0d;
+ double kurtB = 0d;
+ double varA = 0d;
+ double varB = 0d;
+ double cVar = 0d;
+ for (int n = 0; n < count; ++n) {
+ dA = fieldA.get(n) - means.get(fieldAKey);
+ varA += dA * dA;
+ skewA += dA * dA * dA;
+ kurtA += dA * dA * dA * dA;
+ dB = fieldB.get(n) - means.get(fieldBKey);
+ varB += dB * dB;
+ skewB += dB * dB * dB;
+ kurtB += dB * dB * dB * dB;
+ cVar += dA * dB;
+ }
+ variances.put(fieldAKey, varA / (count - 1));
+ final double stdA = Math.sqrt(variances.get(fieldAKey));
+ variances.put(fieldBKey, varB / (count - 1));
+ final double stdB = Math.sqrt(variances.get(fieldBKey));
+ skewness.put(fieldAKey, skewA / ((count - 1) * variances.get(fieldAKey) * stdA));
+ skewness.put(fieldBKey, skewB / ((count - 1) * variances.get(fieldBKey) * stdB));
+ kurtosis.put(fieldAKey, kurtA / ((count - 1) * variances.get(fieldAKey) * variances.get(fieldAKey)));
+ kurtosis.put(fieldBKey, kurtB / ((count - 1) * variances.get(fieldBKey) * variances.get(fieldBKey)));
+
+ // compute covariance
+ final HashMap<String, Double> fieldACovar = new HashMap<>(2);
+ fieldACovar.put(fieldAKey, 1d);
+ cVar /= count - 1;
+ fieldACovar.put(fieldBKey, cVar);
+ covariances.put(fieldAKey, fieldACovar);
+ final HashMap<String, Double> fieldBCovar = new HashMap<>(2);
+ fieldBCovar.put(fieldAKey, cVar);
+ fieldBCovar.put(fieldBKey, 1d);
+ covariances.put(fieldBKey, fieldBCovar);
+
+ // compute correlation
+ final HashMap<String, Double> fieldACorr = new HashMap<>();
+ fieldACorr.put(fieldAKey, 1d);
+ double corr = covariances.get(fieldAKey).get(fieldBKey);
+ corr /= stdA * stdB;
+ fieldACorr.put(fieldBKey, corr);
+ correlations.put(fieldAKey, fieldACorr);
+ final HashMap<String, Double> fieldBCorr = new HashMap<>();
+ fieldBCorr.put(fieldAKey, corr);
+ fieldBCorr.put(fieldBKey, 1d);
+ correlations.put(fieldBKey, fieldBCorr);
+ }
+
+ void assertNearlyEqual(MatrixStatsResults stats) {
+ assertEquals(count, stats.getDocCount());
+ assertEquals(count, stats.getFieldCount(fieldAKey));
+ assertEquals(count, stats.getFieldCount(fieldBKey));
+ // means
+ assertTrue(nearlyEqual(means.get(fieldAKey), stats.getMean(fieldAKey), 1e-7));
+ assertTrue(nearlyEqual(means.get(fieldBKey), stats.getMean(fieldBKey), 1e-7));
+ // variances
+ assertTrue(nearlyEqual(variances.get(fieldAKey), stats.getVariance(fieldAKey), 1e-7));
+ assertTrue(nearlyEqual(variances.get(fieldBKey), stats.getVariance(fieldBKey), 1e-7));
+ // skewness (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
+ assertTrue(nearlyEqual(skewness.get(fieldAKey), stats.getSkewness(fieldAKey), 1e-4));
+ assertTrue(nearlyEqual(skewness.get(fieldBKey), stats.getSkewness(fieldBKey), 1e-4));
+ // kurtosis (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
+ assertTrue(nearlyEqual(kurtosis.get(fieldAKey), stats.getKurtosis(fieldAKey), 1e-4));
+ assertTrue(nearlyEqual(kurtosis.get(fieldBKey), stats.getKurtosis(fieldBKey), 1e-4));
+ // covariances
+ assertTrue(nearlyEqual(covariances.get(fieldAKey).get(fieldBKey),stats.getCovariance(fieldAKey, fieldBKey), 1e-7));
+ assertTrue(nearlyEqual(covariances.get(fieldBKey).get(fieldAKey),stats.getCovariance(fieldBKey, fieldAKey), 1e-7));
+ // correlation
+ assertTrue(nearlyEqual(correlations.get(fieldAKey).get(fieldBKey), stats.getCorrelation(fieldAKey, fieldBKey), 1e-7));
+ assertTrue(nearlyEqual(correlations.get(fieldBKey).get(fieldAKey), stats.getCorrelation(fieldBKey, fieldAKey), 1e-7));
+ }
+
+ private static boolean nearlyEqual(double a, double b, double epsilon) {
+ final double absA = Math.abs(a);
+ final double absB = Math.abs(b);
+ final double diff = Math.abs(a - b);
+
+ if (a == b) { // shortcut, handles infinities
+ return true;
+ } else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) {
+ // a or b is zero or both are extremely close to it
+ // relative error is less meaningful here
+ return diff < (epsilon * Double.MIN_NORMAL);
+ } else { // use relative error
+ return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon;
+ }
+ }
+}