diff --git a/CHANGELOG.md b/CHANGELOG.md index a76f222901..cd84337ef3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Added + +- Added support to allow support for manual keys in add_columns as well. Was discussed in: https://github.com/Textualize/textual/discussions/5922 + ### Fixed - Fixed issue with the "transparent" CSS value not being transparent when set using python https://github.com/Textualize/textual/pull/5890 diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index caf6e760f3..a67750dea8 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -4,7 +4,16 @@ from dataclasses import dataclass from itertools import chain, zip_longest from operator import itemgetter -from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar +from typing import ( + Any, + Callable, + ClassVar, + Generic, + Iterable, + NamedTuple, + TypeVar, + Union, +) import rich.repr from rich.console import RenderableType @@ -1716,20 +1725,41 @@ def add_row( self.check_idle() return row_key - def add_columns(self, *labels: TextType) -> list[ColumnKey]: - """Add a number of columns. + def add_columns( + self, *columns: Union[TextType, tuple[TextType, str]] + ) -> list[ColumnKey]: + """Add multiple columns to the DataTable. Args: - *labels: Column headers. + *columns: Column specifications. Each can be either: + - A string or Text object (label only, auto-generated key) + - A tuple of (label, key) for manual key control Returns: A list of the keys for the columns that were added. See the `add_column` method docstring for more information on how these keys are used. + + Examples: + ```python + # Add columns with auto-generated keys + keys = table.add_columns("Name", "Age", "City") + + # Add columns with manual keys + keys = table.add_columns( + ("Name", "name_col"), + ("Age", "age_col"), + "City" # Mixed with auto-generated key + ) + ``` """ column_keys = [] - for label in labels: - column_key = self.add_column(label, width=None) + for column in columns: + if isinstance(column, tuple): + label, key = column + column_key = self.add_column(label, width=None, key=key) + else: + column_key = self.add_column(column, width=None) column_keys.append(column_key) return column_keys diff --git a/tests/test_data_table.py b/tests/test_data_table.py index d78d49e2fa..63032c67b0 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -299,6 +299,25 @@ async def test_add_columns(): assert len(table.columns) == 3 +async def test_add_columns_with_tuples(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_keys = table.add_columns( + ("Column 1", "col1"), "Column 2", ("Column 3", "col3") + ) + assert len(column_keys) == 3 + assert len(table.columns) == 3 + + assert column_keys[0] == "col1" + assert column_keys[1] != "col1" + assert column_keys[2] == "col3" + + assert table.columns[column_keys[0]].label.plain == "Column 1" + assert table.columns[column_keys[1]].label.plain == "Column 2" + assert table.columns[column_keys[2]].label.plain == "Column 3" + + async def test_add_columns_user_defined_keys(): app = DataTableApp() async with app.run_test():