diff --git a/django_tables2/views.py b/django_tables2/views.py index 68c42eb6..9a33d4c6 100644 --- a/django_tables2/views.py +++ b/django_tables2/views.py @@ -204,22 +204,32 @@ class MultiTableMixin(TableMixinBase): # override context table name to make sense in a multiple table context context_table_name = "tables" - def get_tables(self): + def get_tables_classes(self): """ - Return an array of table instances containing data. + Return the list of classes to use for the tables. """ if self.tables is None: - view_name = type(self).__name__ - raise ImproperlyConfigured(f"No tables were specified. Define {view_name}.tables") + klass = type(self).__name__ + raise ImproperlyConfigured( + f"You must either specify {klass}.tables or override {klass}.get_tables_classes()" + ) + + return self.tables + + def get_tables(self, **kwargs): + """ + Return an array of table instances containing data. + """ + tables = self.get_tables_classes() data = self.get_tables_data() if data is None: - return self.tables + return tables - if len(data) != len(self.tables): - view_name = type(self).__name__ - raise ImproperlyConfigured(f"len({view_name}.tables_data) != len({view_name}.tables)") - return list(Table(data[i]) for i, Table in enumerate(self.tables)) + if len(data) != len(tables): + klass = type(self).__name__ + raise ImproperlyConfigured(f"len({klass}.tables_data) != len({klass}.tables)") + return list(Table(data[i], **kwargs) for i, Table in enumerate(tables)) def get_tables_data(self): """ diff --git a/tests/test_views.py b/tests/test_views.py index 9012bbb8..b7c0e2c9 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -411,15 +411,42 @@ def test_without_tables(self): class View(tables.MultiTableMixin, TemplateView): template_name = "multiple.html" - message = "No tables were specified. Define View.tables" + message = "You must either specify View.tables or override View.get_tables_classes()" with self.assertRaisesMessage(ImproperlyConfigured, message): View.as_view()(build_request("/")) + def test_get_tables_classes_list(self): + class View(tables.MultiTableMixin, TemplateView): + tables_data = (Person.objects.all(), Region.objects.all()) + template_name = "multiple.html" + + def get_tables_classes(self): + return [TableA, TableB] + + response = View.as_view()(build_request("/")) + response.render() + + html = response.rendered_content + self.assertEqual(html.count("