Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_plot_period_transactions_mbgf(self, cd_data):
    """Plot period transactions for a ModifiedBetaGeoFitter fitted on ``cd_data``.

    Checks the title, both axis labels and the default legend entries
    ("Actual", "Model") of the returned axes. Bar heights are not checked.
    """
    mbgf = ModifiedBetaGeoFitter()
    mbgf.fit(cd_data['frequency'], cd_data['recency'], cd_data['T'], iterative_fitting=1)
    ax = plotting.plot_period_transactions(mbgf)
    try:
        assert_equal(ax.get_title(), "Frequency of Repeat Transactions")
        assert_equal(ax.get_xlabel(), "Number of Calibration Period Transactions")
        assert_equal(ax.get_ylabel(), "Customers")
        # get_legend() is the public accessor for the private ``legend_`` attribute.
        legend_texts = [text.get_text() for text in ax.get_legend().get_texts()]
        assert_array_equal(legend_texts, ["Actual", "Model"])
    finally:
        # Close in ``finally`` so a failing assertion does not leave this figure
        # registered with pyplot, where later tests could end up drawing on it.
        plt.close(ax.figure)
def test_plot_period_transactions(self, bgf):
    """Check bar heights, title, axis labels and legend of the default plot.

    ``expected`` holds one height per bar patch, in ``ax.patches`` order.
    The first seven presumably belong to the "Actual" series and the last
    seven to the "Model" series (TODO confirm). Heights are compared with a
    30% relative tolerance.
    """
    expected = [1411, 439, 214, 100, 62, 38, 29, 1411, 439, 214, 100, 62, 38, 29]
    ax = plotting.plot_period_transactions(bgf)
    try:
        heights = [patch.get_height() for patch in ax.patches]
        assert_allclose(heights, expected, rtol=0.3)
        assert_equal(ax.get_title(), "Frequency of Repeat Transactions")
        assert_equal(ax.get_xlabel(), "Number of Calibration Period Transactions")
        assert_equal(ax.get_ylabel(), "Customers")
        # get_legend() is the public accessor for the private ``legend_`` attribute.
        legend_texts = [text.get_text() for text in ax.get_legend().get_texts()]
        assert_array_equal(legend_texts, ["Actual", "Model"])
    finally:
        # Close in ``finally`` so a failing assertion does not leave this figure
        # registered with pyplot, where later tests could end up drawing on it.
        plt.close(ax.figure)
def test_plot_period_transactions_max_frequency(self, bgf):
    """Check the plot when ``max_frequency=12`` is passed.

    ``expected`` holds one height per bar patch, in ``ax.patches`` order.
    The first twelve are presumably the "Actual" bars and the last twelve
    the "Model" bars (TODO confirm). Also checks the title, axis labels and
    legend entries.
    """
    expected = [1411, 439, 214, 100, 62, 38, 29, 23, 7, 5, 5, 5,
                1429, 470, 155, 89, 71, 39, 26, 20, 18, 9, 6, 7]
    ax = plotting.plot_period_transactions(bgf, max_frequency=12)
    try:
        heights = [patch.get_height() for patch in ax.patches]
        # Absolute tolerance: small counts can have large relative differences.
        assert_allclose(heights, expected, atol=50)
        assert_equal(ax.get_title(), "Frequency of Repeat Transactions")
        assert_equal(ax.get_xlabel(), "Number of Calibration Period Transactions")
        assert_equal(ax.get_ylabel(), "Customers")
        # get_legend() is the public accessor for the private ``legend_`` attribute.
        legend_texts = [text.get_text() for text in ax.get_legend().get_texts()]
        assert_array_equal(legend_texts, ["Actual", "Model"])
    finally:
        # Close in ``finally`` so a failing assertion does not leave this figure
        # registered with pyplot, where later tests could end up drawing on it.
        plt.close(ax.figure)
def test_plot_period_transactions_labels(self, bgf):
    """Check that a custom ``label`` list replaces the default legend entries.

    Bar heights are checked against the same ``expected`` values as the
    default plot, with a 30% relative tolerance. The title and axis labels
    must be unchanged, and the legend must read "A", "B".
    """
    expected = [1411, 439, 214, 100, 62, 38, 29, 1411, 439, 214, 100, 62, 38, 29]
    ax = plotting.plot_period_transactions(bgf, label=['A', 'B'])
    try:
        heights = [patch.get_height() for patch in ax.patches]
        assert_allclose(heights, expected, rtol=0.3)
        assert_equal(ax.get_title(), "Frequency of Repeat Transactions")
        assert_equal(ax.get_xlabel(), "Number of Calibration Period Transactions")
        assert_equal(ax.get_ylabel(), "Customers")
        # get_legend() is the public accessor for the private ``legend_`` attribute.
        legend_texts = [text.get_text() for text in ax.get_legend().get_texts()]
        assert_array_equal(legend_texts, ["A", "B"])
    finally:
        # Close in ``finally`` so a failing assertion does not leave this figure
        # registered with pyplot, where later tests could end up drawing on it.
        plt.close(ax.figure)