diff --git a/django/website/hid/tests/term_count_widget_tests.py b/django/website/hid/tests/term_count_widget_tests.py index 632435fa67505752cc6f5359516023d408fcbb54..19f32a8367c9a3b566207ecf229a58a64f74c3be 100644 --- a/django/website/hid/tests/term_count_widget_tests.py +++ b/django/website/hid/tests/term_count_widget_tests.py @@ -230,3 +230,41 @@ class TestTermCountChartWidget(TestCase): self.assertEqual(t1, itemcount_kwargs['start_time']) self.assertEqual(t2, itemcount_kwargs['end_time']) + + def test_categories_can_be_excluded(self): + widget = TermCountChartWidget() + + with patch('hid.widgets.term_count_chart.term_itemcount') as itemcount: + itemcount.return_value = [ + { + 'name': 'Ebola updates', + 'long_name': 'What are the current updates on Ebola.', + 'count': 0, + }, + { + 'name': 'Ebola prevention', + 'long_name': 'What measures could be put in place to end Ebola.', + 'count': 4, + }, + { + 'name': 'Liberia Ebola-free', + 'long_name': 'Can Liberia be Ebola free.', + 'count': 3, + }, + { + 'name': 'Unknown', + 'long_name': 'Unknown.', + 'count': 2, + }, + ] + context_data = widget.get_context_data( + title='test-name', taxonomy='tax', + exclude_categories=['Unknown', 'Liberia Ebola-free'] + ) + ticks = context_data['options']['yaxis']['ticks'] + labels = [t[1] for t in ticks] + self.assertIn('What are the current updates on Ebola.', labels) + self.assertIn('What measures could be put in place to end Ebola.', + labels) + self.assertNotIn('Can Liberia be Ebola free.', labels) + self.assertNotIn('Unknown.', labels) diff --git a/django/website/hid/widgets/term_count_chart.py b/django/website/hid/widgets/term_count_chart.py index 864e9972b844464d335cb15a82703cde293f0ebe..6098b24409fc929f2b6dfca2cb12da9a841560a4 100644 --- a/django/website/hid/widgets/term_count_chart.py +++ b/django/website/hid/widgets/term_count_chart.py @@ -28,7 +28,8 @@ class TermCountChartWidget(object): 'hid/widgets/chart.js' ] - def _fetch_counts(self, taxonomy, count, start, end, other_label): + def _fetch_counts(self, taxonomy, count, start, end, other_label, + exclude_categories=None): """ Given a taxonomy, fetch the count per term. Args: @@ -49,6 +50,11 @@ class TermCountChartWidget(object): ) else: itemcount = term_itemcount(taxonomy) + + if exclude_categories is not None: + itemcount = [t for t in itemcount + if t['name'] not in exclude_categories] + itemcount.sort(key=lambda k: int(k['count']), reverse=True) if count > 0: head = itemcount[0:count-1] @@ -129,6 +135,7 @@ class TermCountChartWidget(object): count = kwargs.get('count', 0) other_label = kwargs.get('other_label', 'Others') periods = kwargs.get('periods', []) + exclude_categories = kwargs.get('exclude_categories') if len(periods) > 1: raise WidgetError('Only one time period is currently supported') @@ -153,7 +160,8 @@ class TermCountChartWidget(object): label = '' counts = self._fetch_counts( - taxonomy, count, start_time, end_time, other_label + taxonomy, count, start_time, end_time, other_label, + exclude_categories ) (values, yticks) = self._create_axis_values(counts) return {