python3.11-scipy/scipy-1.8.1-test_moments-be...

151 lines
7.1 KiB
Diff

diff --git a/benchmarks/benchmarks/stats.py b/benchmarks/benchmarks/stats.py
index 0f105b39..ed399f4d 100644
--- a/benchmarks/benchmarks/stats.py
+++ b/benchmarks/benchmarks/stats.py
@@ -162,6 +162,33 @@ class InferentialStats(Benchmark):
stats.mstats.kruskal(self.a, self.b)
+# Benchmark data for the truncnorm stats() method.
+# The data in each row is:
+# a, b, mean, variance, skewness, excess kurtosis. Generated using
+# https://gist.github.com/WarrenWeckesser/636b537ee889679227d53543d333a720
+truncnorm_cases = [[-20, -19, -19.052343945976656, 0.002725073018195613,
+ -1.9838693623377885, 5.871801893091683],
+ [-30, -29, -29.034401237736176, 0.0011806604886186853,
+ -1.9929615171469608, 5.943905539773037],
+ [-40, -39, -39.02560741993011, 0.0006548827702932775,
+ -1.9960847672775606, 5.968744357649675],
+ [39, 40, 39.02560741993011, 0.0006548827702932775,
+ 1.9960847672775606, 5.968744357649675]]
+truncnorm_cases = np.array(truncnorm_cases)
+
+
+class TruncnormStats(Benchmark):
+ param_names = ['case', 'moment']
+ params = [list(range(len(truncnorm_cases))), ['m', 'v', 's', 'k']]
+
+ def track_truncnorm_stats_error(self, case, moment):
+ result_indices = dict(zip(['m', 'v', 's', 'k'], range(2, 6)))
+ ref = truncnorm_cases[case, result_indices[moment]]
+ a, b = truncnorm_cases[case, 0:2]
+ res = stats.truncnorm(a, b).stats(moments=moment)
+ return np.abs((res - ref)/ref)
+
+
class DistributionsAll(Benchmark):
# all distributions are in this list. A conversion to a set is used to
# remove duplicates that appear more than once in either `distcont` or
diff --git a/scipy/stats/tests/test_distributions.py b/scipy/stats/tests/test_distributions.py
index d18ad6f3..a15a9301 100644
--- a/scipy/stats/tests/test_distributions.py
+++ b/scipy/stats/tests/test_distributions.py
@@ -910,61 +910,52 @@ class TestTruncnorm:
assert_almost_equal(s, s0, decimal=decimal_s)
assert_almost_equal(k, k0)
- @pytest.mark.xfail_on_32bit("reduced accuracy with 32bit platforms.")
- def test_moments(self):
- # Values validated by changing TRUNCNORM_TAIL_X so as to evaluate
- # using both the _norm_XXX() and _norm_logXXX() functions, and by
- # removing the _stats and _munp methods in truncnorm tp force
- # numerical quadrature.
- # For m,v,s,k expect k to have the largest error as it is
- # constructed from powers of lower moments
-
- self._test_moments_one_range(-30, 30, [0, 1, 0.0, 0.0])
- self._test_moments_one_range(-10, 10, [0, 1, 0.0, 0.0])
- self._test_moments_one_range(-3, 3, [0.0, 0.9733369246625415,
- 0.0, -0.1711144363977444])
- self._test_moments_one_range(-2, 2, [0.0, 0.7737413035499232,
- 0.0, -0.6344632828703505])
-
- self._test_moments_one_range(0, np.inf, [0.7978845608028654,
- 0.3633802276324186,
- 0.9952717464311565,
- 0.8691773036059725])
- self._test_moments_one_range(-np.inf, 0, [-0.7978845608028654,
- 0.3633802276324186,
- -0.9952717464311565,
- 0.8691773036059725])
-
- self._test_moments_one_range(-1, 3, [0.2827861107271540,
- 0.6161417353578292,
- 0.5393018494027878,
- -0.2058206513527461])
- self._test_moments_one_range(-3, 1, [-0.2827861107271540,
- 0.6161417353578292,
- -0.5393018494027878,
- -0.2058206513527461])
-
- self._test_moments_one_range(-10, -9, [-9.1084562880124764,
- 0.0114488058210104,
- -1.8985607337519652,
- 5.0733457094223553])
- self._test_moments_one_range(-20, -19, [-19.0523439459766628,
- 0.0027250730180314,
- -1.9838694022629291,
- 5.8717850028287586])
- self._test_moments_one_range(-30, -29, [-29.0344012377394698,
- 0.0011806603928891,
- -1.9930304534611458,
- 5.8854062968996566],
- decimal_s=6)
- self._test_moments_one_range(-40, -39, [-39.0256074199326264,
- 0.0006548826719649,
- -1.9963146354109957,
- 5.6167758371700494])
- self._test_moments_one_range(39, 40, [39.0256074199326264,
- 0.0006548826719649,
- 1.9963146354109957,
- 5.6167758371700494])
+ # Test data for the truncnorm stats() method.
+ # The data in each row is:
+ # a, b, mean, variance, skewness, excess kurtosis. Generated using
+ # https://gist.github.com/WarrenWeckesser/636b537ee889679227d53543d333a720
+ _truncnorm_stats_data = [
+ [-30, 30,
+ 0.0, 1.0, 0.0, 0.0],
+ [-10, 10,
+ 0.0, 1.0, 0.0, -1.4927521335810455e-19],
+ [-3, 3,
+ 0.0, 0.9733369246625415, 0.0, -0.17111443639774404],
+ [-2, 2,
+ 0.0, 0.7737413035499232, 0.0, -0.6344632828703505],
+ [0, np.inf,
+ 0.7978845608028654,
+ 0.3633802276324187,
+ 0.995271746431156,
+ 0.8691773036059741],
+ [-np.inf, 0,
+ -0.7978845608028654,
+ 0.3633802276324187,
+ -0.995271746431156,
+ 0.8691773036059741],
+ [-1, 3,
+ 0.282786110727154,
+ 0.6161417353578293,
+ 0.5393018494027877,
+ -0.20582065135274694],
+ [-3, 1,
+ -0.282786110727154,
+ 0.6161417353578293,
+ -0.5393018494027877,
+ -0.20582065135274694],
+ [-10, -9,
+ -9.108456288012409,
+ 0.011448805821636248,
+ -1.8985607290949496,
+ 5.0733461105025075],
+ ]
+ _truncnorm_stats_data = np.array(_truncnorm_stats_data)
+
+ @pytest.mark.parametrize("case", _truncnorm_stats_data)
+ def test_moments(self, case):
+ a, b, m0, v0, s0, k0 = case
+ m, v, s, k = stats.truncnorm.stats(a, b, moments='mvsk')
+ assert_allclose([m, v, s, k], [m0, v0, s0, k0], atol=1e-17)
def test_9902_moments(self):
m, v = stats.truncnorm.stats(0, np.inf, moments='mv')