diff options
author | Martijn van Groningen <martijn.v.groningen@gmail.com> | 2017-05-09 15:04:28 +0200 |
---|---|---|
committer | Martijn van Groningen <martijn.v.groningen@gmail.com> | 2017-05-10 11:06:18 +0200 |
commit | 51c74ce5476ed4d8b0f2ed3469fee955ff1e5fd1 (patch) | |
tree | 002eeca1732646f83045580dba0449e31ad78e9a /modules/aggs-matrix-stats/src | |
parent | b24326271e6778d5d595005e7e1e4258e7e7ee24 (diff) |
Added unit tests for InternalMatrixStats.
Also moved InternalAggregationTestCase to the test-framework module so that modules other than core can use it.
Relates to #22278
Diffstat (limited to 'modules/aggs-matrix-stats/src')
6 files changed, 315 insertions, 126 deletions
diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java index 0914ea2910..5b7d2cf288 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import static java.util.Collections.emptyMap; @@ -41,7 +42,7 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt private final MatrixStatsResults results; /** per shard ctor */ - protected InternalMatrixStats(String name, long count, RunningStats multiFieldStatsResults, MatrixStatsResults results, + InternalMatrixStats(String name, long count, RunningStats multiFieldStatsResults, MatrixStatsResults results, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) { super(name, pipelineAggregators, metaData); assert count >= 0; @@ -138,6 +139,10 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt return results.getCorrelation(fieldX, fieldY); } + MatrixStatsResults getResults() { + return results; + } + static class Fields { public static final String FIELDS = "fields"; public static final String NAME = "name"; @@ -238,4 +243,16 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt return new InternalMatrixStats(name, results.getDocCount(), runningStats, results, pipelineAggregators(), getMetaData()); } + + @Override + protected int doHashCode() { + return Objects.hash(stats, results); + } + + @Override + protected boolean doEquals(Object obj) { + InternalMatrixStats other = 
(InternalMatrixStats) obj; + return Objects.equals(this.stats, other.stats) && + Objects.equals(this.results, other.results); + } } diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java index f82c6df73b..4da8b7ca61 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Objects; /** * Descriptive stats gathered per shard. Coordinating node computes final pearson product coefficient @@ -228,4 +229,18 @@ class MatrixStatsResults implements Writeable { correlation.put(rowName, corRow); } } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MatrixStatsResults that = (MatrixStatsResults) o; + return Objects.equals(results, that.results) && + Objects.equals(correlation, that.correlation); + } + + @Override + public int hashCode() { + return Objects.hash(results, correlation); + } } diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/RunningStats.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/RunningStats.java index 81d0d0a494..1be3279e8e 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/RunningStats.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/RunningStats.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.Arrays; import 
java.util.HashMap; import java.util.Map; +import java.util.Objects; /** * Descriptive stats gathered per shard. Coordinating node computes final correlation and covariance stats @@ -53,11 +54,11 @@ public class RunningStats implements Writeable, Cloneable { /** covariance values */ protected HashMap<String, HashMap<String, Double>> covariances; - public RunningStats() { + RunningStats() { init(); } - public RunningStats(final String[] fieldNames, final double[] fieldVals) { + RunningStats(final String[] fieldNames, final double[] fieldVals) { if (fieldVals != null && fieldVals.length > 0) { init(); this.add(fieldNames, fieldVals); @@ -309,4 +310,24 @@ public class RunningStats implements Writeable, Cloneable { throw new ElasticsearchException("Error trying to create a copy of RunningStats"); } } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RunningStats that = (RunningStats) o; + return docCount == that.docCount && + Objects.equals(fieldSum, that.fieldSum) && + Objects.equals(counts, that.counts) && + Objects.equals(means, that.means) && + Objects.equals(variances, that.variances) && + Objects.equals(skewness, that.skewness) && + Objects.equals(kurtosis, that.kurtosis) && + Objects.equals(covariances, that.covariances); + } + + @Override + public int hashCode() { + return Objects.hash(docCount, fieldSum, counts, means, variances, skewness, kurtosis, covariances); + } } diff --git a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/BaseMatrixStatsTestCase.java b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/BaseMatrixStatsTestCase.java index 81c9d51463..091235bf82 100644 --- a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/BaseMatrixStatsTestCase.java +++ 
b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/BaseMatrixStatsTestCase.java @@ -22,15 +22,12 @@ import org.elasticsearch.test.ESTestCase; import org.junit.Before; import java.util.ArrayList; -import java.util.HashMap; - -import static org.hamcrest.Matchers.equalTo; public abstract class BaseMatrixStatsTestCase extends ESTestCase { protected final int numObs = atLeast(10000); protected final ArrayList<Double> fieldA = new ArrayList<>(numObs); protected final ArrayList<Double> fieldB = new ArrayList<>(numObs); - protected final MultiPassStats actualStats = new MultiPassStats(); + protected final MultiPassStats actualStats = new MultiPassStats(fieldAKey, fieldBKey); protected static final String fieldAKey = "fieldA"; protected static final String fieldBKey = "fieldB"; @@ -47,123 +44,4 @@ public abstract class BaseMatrixStatsTestCase extends ESTestCase { actualStats.computeStats(fieldA, fieldB); } - static class MultiPassStats { - long count; - HashMap<String, Double> means = new HashMap<>(); - HashMap<String, Double> variances = new HashMap<>(); - HashMap<String, Double> skewness = new HashMap<>(); - HashMap<String, Double> kurtosis = new HashMap<>(); - HashMap<String, HashMap<String, Double>> covariances = new HashMap<>(); - HashMap<String, HashMap<String, Double>> correlations = new HashMap<>(); - - @SuppressWarnings("unchecked") - void computeStats(final ArrayList<Double> fieldA, final ArrayList<Double> fieldB) { - // set count - count = fieldA.size(); - double meanA = 0d; - double meanB = 0d; - - // compute mean - for (int n = 0; n < count; ++n) { - // fieldA - meanA += fieldA.get(n); - meanB += fieldB.get(n); - } - means.put(fieldAKey, meanA/count); - means.put(fieldBKey, meanB/count); - - // compute variance, skewness, and kurtosis - double dA; - double dB; - double skewA = 0d; - double skewB = 0d; - double kurtA = 0d; - double kurtB = 0d; - double varA = 0d; - double varB = 0d; - double cVar = 0d; - for (int n = 0; 
n < count; ++n) { - dA = fieldA.get(n) - means.get(fieldAKey); - varA += dA * dA; - skewA += dA * dA * dA; - kurtA += dA * dA * dA * dA; - dB = fieldB.get(n) - means.get(fieldBKey); - varB += dB * dB; - skewB += dB * dB * dB; - kurtB += dB * dB * dB * dB; - cVar += dA * dB; - } - variances.put(fieldAKey, varA / (count - 1)); - final double stdA = Math.sqrt(variances.get(fieldAKey)); - variances.put(fieldBKey, varB / (count - 1)); - final double stdB = Math.sqrt(variances.get(fieldBKey)); - skewness.put(fieldAKey, skewA / ((count - 1) * variances.get(fieldAKey) * stdA)); - skewness.put(fieldBKey, skewB / ((count - 1) * variances.get(fieldBKey) * stdB)); - kurtosis.put(fieldAKey, kurtA / ((count - 1) * variances.get(fieldAKey) * variances.get(fieldAKey))); - kurtosis.put(fieldBKey, kurtB / ((count - 1) * variances.get(fieldBKey) * variances.get(fieldBKey))); - - // compute covariance - final HashMap<String, Double> fieldACovar = new HashMap<>(2); - fieldACovar.put(fieldAKey, 1d); - cVar /= count - 1; - fieldACovar.put(fieldBKey, cVar); - covariances.put(fieldAKey, fieldACovar); - final HashMap<String, Double> fieldBCovar = new HashMap<>(2); - fieldBCovar.put(fieldAKey, cVar); - fieldBCovar.put(fieldBKey, 1d); - covariances.put(fieldBKey, fieldBCovar); - - // compute correlation - final HashMap<String, Double> fieldACorr = new HashMap<>(); - fieldACorr.put(fieldAKey, 1d); - double corr = covariances.get(fieldAKey).get(fieldBKey); - corr /= stdA * stdB; - fieldACorr.put(fieldBKey, corr); - correlations.put(fieldAKey, fieldACorr); - final HashMap<String, Double> fieldBCorr = new HashMap<>(); - fieldBCorr.put(fieldAKey, corr); - fieldBCorr.put(fieldBKey, 1d); - correlations.put(fieldBKey, fieldBCorr); - } - - public void assertNearlyEqual(MatrixStatsResults stats) { - assertThat(count, equalTo(stats.getDocCount())); - assertThat(count, equalTo(stats.getFieldCount(fieldAKey))); - assertThat(count, equalTo(stats.getFieldCount(fieldBKey))); - // means - 
assertTrue(nearlyEqual(means.get(fieldAKey), stats.getMean(fieldAKey), 1e-7)); - assertTrue(nearlyEqual(means.get(fieldBKey), stats.getMean(fieldBKey), 1e-7)); - // variances - assertTrue(nearlyEqual(variances.get(fieldAKey), stats.getVariance(fieldAKey), 1e-7)); - assertTrue(nearlyEqual(variances.get(fieldBKey), stats.getVariance(fieldBKey), 1e-7)); - // skewness (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance) - assertTrue(nearlyEqual(skewness.get(fieldAKey), stats.getSkewness(fieldAKey), 1e-4)); - assertTrue(nearlyEqual(skewness.get(fieldBKey), stats.getSkewness(fieldBKey), 1e-4)); - // kurtosis (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance) - assertTrue(nearlyEqual(kurtosis.get(fieldAKey), stats.getKurtosis(fieldAKey), 1e-4)); - assertTrue(nearlyEqual(kurtosis.get(fieldBKey), stats.getKurtosis(fieldBKey), 1e-4)); - // covariances - assertTrue(nearlyEqual(covariances.get(fieldAKey).get(fieldBKey), stats.getCovariance(fieldAKey, fieldBKey), 1e-7)); - assertTrue(nearlyEqual(covariances.get(fieldBKey).get(fieldAKey), stats.getCovariance(fieldBKey, fieldAKey), 1e-7)); - // correlation - assertTrue(nearlyEqual(correlations.get(fieldAKey).get(fieldBKey), stats.getCorrelation(fieldAKey, fieldBKey), 1e-7)); - assertTrue(nearlyEqual(correlations.get(fieldBKey).get(fieldAKey), stats.getCorrelation(fieldBKey, fieldAKey), 1e-7)); - } - } - - private static boolean nearlyEqual(double a, double b, double epsilon) { - final double absA = Math.abs(a); - final double absB = Math.abs(b); - final double diff = Math.abs(a - b); - - if (a == b) { // shortcut, handles infinities - return true; - } else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) { - // a or b is zero or both are extremely close to it - // relative error is less meaningful here - return diff < (epsilon * Double.MIN_NORMAL); - } else { // use relative error - return diff / Math.min((absA + absB), Double.MAX_VALUE) < 
epsilon; - } - } } diff --git a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java new file mode 100644 index 0000000000..277006da90 --- /dev/null +++ b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java @@ -0,0 +1,103 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.search.aggregations.matrix.stats; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.script.ScriptService; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.test.InternalAggregationTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class InternalMatrixStatsTests extends InternalAggregationTestCase<InternalMatrixStats> { + + @Override + protected InternalMatrixStats createTestInstance(String name, List<PipelineAggregator> pipelineAggregators, + Map<String, Object> metaData) { + int numFields = randomInt(128); + String[] fieldNames = new String[numFields]; + double[] fieldValues = new double[numFields]; + for (int i = 0; i < numFields; i++) { + fieldNames[i] = Integer.toString(i); + fieldValues[i] = randomDouble(); + } + RunningStats runningStats = new RunningStats(); + runningStats.add(fieldNames, fieldValues); + MatrixStatsResults matrixStatsResults = randomBoolean() ? 
new MatrixStatsResults(runningStats) : null; + return new InternalMatrixStats("_name", 1L, runningStats, matrixStatsResults, Collections.emptyList(), Collections.emptyMap()); + } + + @Override + protected Writeable.Reader<InternalMatrixStats> instanceReader() { + return InternalMatrixStats::new; + } + + @Override + public void testReduceRandom() { + int numValues = 10000; + int numShards = randomIntBetween(1, 20); + int valuesPerShard = (int) Math.floor(numValues / numShards); + + List<Double> aValues = new ArrayList<>(); + List<Double> bValues = new ArrayList<>(); + + RunningStats runningStats = new RunningStats(); + List<InternalAggregation> shardResults = new ArrayList<>(); + + int valuePerShardCounter = 0; + for (int i = 0; i < numValues; i++) { + double valueA = randomDouble(); + aValues.add(valueA); + double valueB = randomDouble(); + bValues.add(valueB); + + runningStats.add(new String[]{"a", "b"}, new double[]{valueA, valueB}); + if (++valuePerShardCounter == valuesPerShard) { + shardResults.add(new InternalMatrixStats("_name", 1L, runningStats, null, Collections.emptyList(), Collections.emptyMap())); + runningStats = new RunningStats(); + valuePerShardCounter = 0; + } + } + + if (valuePerShardCounter != 0) { + shardResults.add(new InternalMatrixStats("_name", 1L, runningStats, null, Collections.emptyList(), Collections.emptyMap())); + } + MultiPassStats multiPassStats = new MultiPassStats("a", "b"); + multiPassStats.computeStats(aValues, bValues); + + ScriptService mockScriptService = mockScriptService(); + MockBigArrays bigArrays = new MockBigArrays(Settings.EMPTY, new NoneCircuitBreakerService()); + InternalAggregation.ReduceContext context = + new InternalAggregation.ReduceContext(bigArrays, mockScriptService, true); + InternalMatrixStats reduced = (InternalMatrixStats) shardResults.get(0).reduce(shardResults, context); + multiPassStats.assertNearlyEqual(reduced.getResults()); + } + + @Override + protected void assertReduced(InternalMatrixStats reduced, 
List<InternalMatrixStats> inputs) { + throw new UnsupportedOperationException(); + } +} diff --git a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MultiPassStats.java b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MultiPassStats.java new file mode 100644 index 0000000000..70e2172ce9 --- /dev/null +++ b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MultiPassStats.java @@ -0,0 +1,155 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.search.aggregations.matrix.stats; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +class MultiPassStats { + + private final String fieldAKey; + private final String fieldBKey; + + private long count; + private Map<String, Double> means = new HashMap<>(); + private Map<String, Double> variances = new HashMap<>(); + private Map<String, Double> skewness = new HashMap<>(); + private Map<String, Double> kurtosis = new HashMap<>(); + private Map<String, HashMap<String, Double>> covariances = new HashMap<>(); + private Map<String, HashMap<String, Double>> correlations = new HashMap<>(); + + MultiPassStats(String fieldAName, String fieldBName) { + this.fieldAKey = fieldAName; + this.fieldBKey = fieldBName; + } + + @SuppressWarnings("unchecked") + void computeStats(final List<Double> fieldA, final List<Double> fieldB) { + // set count + count = fieldA.size(); + double meanA = 0d; + double meanB = 0d; + + // compute mean + for (int n = 0; n < count; ++n) { + // fieldA + meanA += fieldA.get(n); + meanB += fieldB.get(n); + } + means.put(fieldAKey, meanA / count); + means.put(fieldBKey, meanB / count); + + // compute variance, skewness, and kurtosis + double dA; + double dB; + double skewA = 0d; + double skewB = 0d; + double kurtA = 0d; + double kurtB = 0d; + double varA = 0d; + double varB = 0d; + double cVar = 0d; + for (int n = 0; n < count; ++n) { + dA = fieldA.get(n) - means.get(fieldAKey); + varA += dA * dA; + skewA += dA * dA * dA; + kurtA += dA * dA * dA * dA; + dB = fieldB.get(n) - means.get(fieldBKey); + varB += dB * dB; + skewB += dB * dB * dB; + kurtB += dB * dB * dB * dB; + cVar += dA * dB; + } + variances.put(fieldAKey, varA / (count - 1)); + final double stdA = Math.sqrt(variances.get(fieldAKey)); + variances.put(fieldBKey, varB / (count - 1)); + final double stdB = Math.sqrt(variances.get(fieldBKey)); + 
skewness.put(fieldAKey, skewA / ((count - 1) * variances.get(fieldAKey) * stdA)); + skewness.put(fieldBKey, skewB / ((count - 1) * variances.get(fieldBKey) * stdB)); + kurtosis.put(fieldAKey, kurtA / ((count - 1) * variances.get(fieldAKey) * variances.get(fieldAKey))); + kurtosis.put(fieldBKey, kurtB / ((count - 1) * variances.get(fieldBKey) * variances.get(fieldBKey))); + + // compute covariance + final HashMap<String, Double> fieldACovar = new HashMap<>(2); + fieldACovar.put(fieldAKey, 1d); + cVar /= count - 1; + fieldACovar.put(fieldBKey, cVar); + covariances.put(fieldAKey, fieldACovar); + final HashMap<String, Double> fieldBCovar = new HashMap<>(2); + fieldBCovar.put(fieldAKey, cVar); + fieldBCovar.put(fieldBKey, 1d); + covariances.put(fieldBKey, fieldBCovar); + + // compute correlation + final HashMap<String, Double> fieldACorr = new HashMap<>(); + fieldACorr.put(fieldAKey, 1d); + double corr = covariances.get(fieldAKey).get(fieldBKey); + corr /= stdA * stdB; + fieldACorr.put(fieldBKey, corr); + correlations.put(fieldAKey, fieldACorr); + final HashMap<String, Double> fieldBCorr = new HashMap<>(); + fieldBCorr.put(fieldAKey, corr); + fieldBCorr.put(fieldBKey, 1d); + correlations.put(fieldBKey, fieldBCorr); + } + + void assertNearlyEqual(MatrixStatsResults stats) { + assertEquals(count, stats.getDocCount()); + assertEquals(count, stats.getFieldCount(fieldAKey)); + assertEquals(count, stats.getFieldCount(fieldBKey)); + // means + assertTrue(nearlyEqual(means.get(fieldAKey), stats.getMean(fieldAKey), 1e-7)); + assertTrue(nearlyEqual(means.get(fieldBKey), stats.getMean(fieldBKey), 1e-7)); + // variances + assertTrue(nearlyEqual(variances.get(fieldAKey), stats.getVariance(fieldAKey), 1e-7)); + assertTrue(nearlyEqual(variances.get(fieldBKey), stats.getVariance(fieldBKey), 1e-7)); + // skewness (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance) + assertTrue(nearlyEqual(skewness.get(fieldAKey), stats.getSkewness(fieldAKey), 
1e-4)); + assertTrue(nearlyEqual(skewness.get(fieldBKey), stats.getSkewness(fieldBKey), 1e-4)); + // kurtosis (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance) + assertTrue(nearlyEqual(kurtosis.get(fieldAKey), stats.getKurtosis(fieldAKey), 1e-4)); + assertTrue(nearlyEqual(kurtosis.get(fieldBKey), stats.getKurtosis(fieldBKey), 1e-4)); + // covariances + assertTrue(nearlyEqual(covariances.get(fieldAKey).get(fieldBKey),stats.getCovariance(fieldAKey, fieldBKey), 1e-7)); + assertTrue(nearlyEqual(covariances.get(fieldBKey).get(fieldAKey),stats.getCovariance(fieldBKey, fieldAKey), 1e-7)); + // correlation + assertTrue(nearlyEqual(correlations.get(fieldAKey).get(fieldBKey), stats.getCorrelation(fieldAKey, fieldBKey), 1e-7)); + assertTrue(nearlyEqual(correlations.get(fieldBKey).get(fieldAKey), stats.getCorrelation(fieldBKey, fieldAKey), 1e-7)); + } + + private static boolean nearlyEqual(double a, double b, double epsilon) { + final double absA = Math.abs(a); + final double absB = Math.abs(b); + final double diff = Math.abs(a - b); + + if (a == b) { // shortcut, handles infinities + return true; + } else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) { + // a or b is zero or both are extremely close to it + // relative error is less meaningful here + return diff < (epsilon * Double.MIN_NORMAL); + } else { // use relative error + return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon; + } + } +} |