d384cbdb91
Resolves: rhbz#2157723
151 lines
7.1 KiB
Diff
151 lines
7.1 KiB
Diff
diff --git a/benchmarks/benchmarks/stats.py b/benchmarks/benchmarks/stats.py
|
|
index 0f105b39..ed399f4d 100644
|
|
--- a/benchmarks/benchmarks/stats.py
|
|
+++ b/benchmarks/benchmarks/stats.py
|
|
@@ -162,6 +162,33 @@ class InferentialStats(Benchmark):
|
|
stats.mstats.kruskal(self.a, self.b)
|
|
|
|
|
|
+# Reference truncnorm.stats() values for intervals far out in the tails.
|
|
+# Each row is: a, b, mean, variance, skewness, excess kurtosis.
|
|
+# The values were generated with the script at
|
|
+# https://gist.github.com/WarrenWeckesser/636b537ee889679227d53543d333a720
|
|
+truncnorm_cases = [[-20, -19, -19.052343945976656, 0.002725073018195613,
|
|
+                    -1.9838693623377885, 5.871801893091683],
|
|
+                   [-30, -29, -29.034401237736176, 0.0011806604886186853,
|
|
+                    -1.9929615171469608, 5.943905539773037],
|
|
+                   [-40, -39, -39.02560741993011, 0.0006548827702932775,
|
|
+                    -1.9960847672775606, 5.968744357649675],
|
|
+                   [39, 40, 39.02560741993011, 0.0006548827702932775,
|
|
+                    1.9960847672775606, 5.968744357649675]]
|
|
+truncnorm_cases = np.array(truncnorm_cases)  # allows [row, col] indexing
|
|
+
|
|
+
|
|
+class TruncnormStats(Benchmark):
|
|
+    param_names = ['case', 'moment']  # row of truncnorm_cases, moment code
|
|
+    params = [list(range(len(truncnorm_cases))), ['m', 'v', 's', 'k']]
|
|
+
|
|
+    def track_truncnorm_stats_error(self, case, moment):
|
|
+        """Relative error of truncnorm.stats() vs. the reference value."""
|
|
+        ref = truncnorm_cases[case, 2 + 'mvsk'.index(moment)]  # cols 2-5
|
|
+        a, b = truncnorm_cases[case, 0:2]  # cols 0-1: truncation bounds
|
|
+        res = stats.truncnorm(a, b).stats(moments=moment)
|
|
+        return np.abs((res - ref)/ref)  # relative error; refs are nonzero
|
|
+
|
|
+
|
|
class DistributionsAll(Benchmark):
|
|
# all distributions are in this list. A conversion to a set is used to
|
|
# remove duplicates that appear more than once in either `distcont` or
|
|
diff --git a/scipy/stats/tests/test_distributions.py b/scipy/stats/tests/test_distributions.py
|
|
index d18ad6f3..a15a9301 100644
|
|
--- a/scipy/stats/tests/test_distributions.py
|
|
+++ b/scipy/stats/tests/test_distributions.py
|
|
@@ -910,61 +910,52 @@ class TestTruncnorm:
|
|
assert_almost_equal(s, s0, decimal=decimal_s)
|
|
assert_almost_equal(k, k0)
|
|
|
|
- @pytest.mark.xfail_on_32bit("reduced accuracy with 32bit platforms.")
|
|
- def test_moments(self):
|
|
- # Values validated by changing TRUNCNORM_TAIL_X so as to evaluate
|
|
- # using both the _norm_XXX() and _norm_logXXX() functions, and by
|
|
- # removing the _stats and _munp methods in truncnorm tp force
|
|
- # numerical quadrature.
|
|
- # For m,v,s,k expect k to have the largest error as it is
|
|
- # constructed from powers of lower moments
|
|
-
|
|
- self._test_moments_one_range(-30, 30, [0, 1, 0.0, 0.0])
|
|
- self._test_moments_one_range(-10, 10, [0, 1, 0.0, 0.0])
|
|
- self._test_moments_one_range(-3, 3, [0.0, 0.9733369246625415,
|
|
- 0.0, -0.1711144363977444])
|
|
- self._test_moments_one_range(-2, 2, [0.0, 0.7737413035499232,
|
|
- 0.0, -0.6344632828703505])
|
|
-
|
|
- self._test_moments_one_range(0, np.inf, [0.7978845608028654,
|
|
- 0.3633802276324186,
|
|
- 0.9952717464311565,
|
|
- 0.8691773036059725])
|
|
- self._test_moments_one_range(-np.inf, 0, [-0.7978845608028654,
|
|
- 0.3633802276324186,
|
|
- -0.9952717464311565,
|
|
- 0.8691773036059725])
|
|
-
|
|
- self._test_moments_one_range(-1, 3, [0.2827861107271540,
|
|
- 0.6161417353578292,
|
|
- 0.5393018494027878,
|
|
- -0.2058206513527461])
|
|
- self._test_moments_one_range(-3, 1, [-0.2827861107271540,
|
|
- 0.6161417353578292,
|
|
- -0.5393018494027878,
|
|
- -0.2058206513527461])
|
|
-
|
|
- self._test_moments_one_range(-10, -9, [-9.1084562880124764,
|
|
- 0.0114488058210104,
|
|
- -1.8985607337519652,
|
|
- 5.0733457094223553])
|
|
- self._test_moments_one_range(-20, -19, [-19.0523439459766628,
|
|
- 0.0027250730180314,
|
|
- -1.9838694022629291,
|
|
- 5.8717850028287586])
|
|
- self._test_moments_one_range(-30, -29, [-29.0344012377394698,
|
|
- 0.0011806603928891,
|
|
- -1.9930304534611458,
|
|
- 5.8854062968996566],
|
|
- decimal_s=6)
|
|
- self._test_moments_one_range(-40, -39, [-39.0256074199326264,
|
|
- 0.0006548826719649,
|
|
- -1.9963146354109957,
|
|
- 5.6167758371700494])
|
|
- self._test_moments_one_range(39, 40, [39.0256074199326264,
|
|
- 0.0006548826719649,
|
|
- 1.9963146354109957,
|
|
- 5.6167758371700494])
|
|
+    # Reference values for truncnorm.stats(a, b, moments='mvsk'). Each row
|
|
+    # is: a, b, mean, variance, skewness, excess kurtosis. Generated using
|
|
+    # https://gist.github.com/WarrenWeckesser/636b537ee889679227d53543d333a720
|
|
+    # Some references are 0 or ~1e-19, so test_moments adds an atol.
|
|
+    _truncnorm_stats_data = [
|
|
+        [-30, 30,
|
|
+         0.0, 1.0, 0.0, 0.0],
|
|
+        [-10, 10,
|
|
+         0.0, 1.0, 0.0, -1.4927521335810455e-19],
|
|
+        [-3, 3,
|
|
+         0.0, 0.9733369246625415, 0.0, -0.17111443639774404],
|
|
+        [-2, 2,
|
|
+         0.0, 0.7737413035499232, 0.0, -0.6344632828703505],
|
|
+        [0, np.inf,
|
|
+         0.7978845608028654,
|
|
+         0.3633802276324187,
|
|
+         0.995271746431156,
|
|
+         0.8691773036059741],
|
|
+        [-np.inf, 0,
|
|
+         -0.7978845608028654,
|
|
+         0.3633802276324187,
|
|
+         -0.995271746431156,
|
|
+         0.8691773036059741],
|
|
+        [-1, 3,
|
|
+         0.282786110727154,
|
|
+         0.6161417353578293,
|
|
+         0.5393018494027877,
|
|
+         -0.20582065135274694],
|
|
+        [-3, 1,
|
|
+         -0.282786110727154,
|
|
+         0.6161417353578293,
|
|
+         -0.5393018494027877,
|
|
+         -0.20582065135274694],
|
|
+        [-10, -9,
|
|
+         -9.108456288012409,
|
|
+         0.011448805821636248,
|
|
+         -1.8985607290949496,
|
|
+         5.0733461105025075],
|
|
+    ]
|
|
+    _truncnorm_stats_data = np.array(_truncnorm_stats_data)
|
|
+
|
|
+    @pytest.mark.parametrize("case", _truncnorm_stats_data)
|
|
+    def test_moments(self, case):
|
|
+        a, b, m0, v0, s0, k0 = case  # bounds, then reference m, v, s, k
|
|
+        m, v, s, k = stats.truncnorm.stats(a, b, moments='mvsk')
|
|
+        assert_allclose([m, v, s, k], [m0, v0, s0, k0], atol=1e-17)
|
|
|
|
def test_9902_moments(self):
|
|
m, v = stats.truncnorm.stats(0, np.inf, moments='mv')
|