diff --git a/st_aggrid/__init__.py b/st_aggrid/__init__.py index b96097d..f64cc97 100644 --- a/st_aggrid/__init__.py +++ b/st_aggrid/__init__.py @@ -41,6 +41,7 @@ def AgGrid( reload_data:bool=False, theme:str='light', custom_css=None, + default_column_filters: typing.Optional[str] = None, key: typing.Any=None, **default_column_parameters) -> typing.Dict: """Reders a DataFrame using AgGrid. @@ -126,6 +127,10 @@ def AgGrid( custom_css (dict, optional): Custom CSS rules to be added to the component's iframe. + default_column_filters : str, optional + Default column filters to apply on instantiation. Requires enterprise license. + Defaults to None. + key : typing.Any, optional Streamlits key argument. Check streamlit's documentation. Defaults to None. @@ -146,6 +151,7 @@ def AgGrid( response = {} response["data"] = dataframe response["selected_rows"] = [] + response["column_filters"] = None #basic numpy types of dataframe frame_dtypes = dict(zip(dataframe.columns, (t.kind for t in dataframe.dtypes))) @@ -225,7 +231,8 @@ def cast_to_serializable(value): reload_data=reload_data, theme=theme, custom_css=custom_css, - key=key + default_column_filters=default_column_filters, + key=key, ) except components.components.MarshallComponentException as ex: @@ -250,7 +257,7 @@ def cast_to_serializable(value): text_columns = [k for k,v in original_types.items() if v in ['O','S','U']] if text_columns: - frame.loc[:,text_columns.keys()] = frame.loc[:,text_columns.keys()].astype(str) + frame.loc[:,text_columns] = frame.loc[:,text_columns].astype(str) date_columns = [k for k,v in original_types.items() if v == "M"] if date_columns: @@ -268,5 +275,6 @@ def cast_to_timedelta(s): response["data"] = frame response["selected_rows"] = component_value["selectedRows"] + response["column_filters"] = component_value["columnFilters"] return response diff --git a/st_aggrid/frontend/src/AgGrid.tsx b/st_aggrid/frontend/src/AgGrid.tsx index 9252118..2196eff 100644 --- a/st_aggrid/frontend/src/AgGrid.tsx +++ b/st_aggrid/frontend/src/AgGrid.tsx @@ -68,6 +68,7 @@ class AgGrid extends StreamlitComponentBase { private allowUnsafeJsCode: boolean = false private fitColumnsOnGridLoad: boolean = false private gridOptions: any + private defaultColumnFilters: string constructor(props: any) { super(props) @@ -89,6 +90,7 @@ class AgGrid extends StreamlitComponentBase { this.manualUpdateRequested = (this.props.args.update_mode === 1) this.allowUnsafeJsCode = this.props.args.allow_unsafe_jscode this.fitColumnsOnGridLoad = this.props.args.fit_columns_on_grid_load + this.defaultColumnFilters = this.props.args.default_column_filters this.columnFormaters = { columnTypes: { @@ -203,7 +205,13 @@ class AgGrid extends StreamlitComponentBase { this.columnApi = event.columnApi this.setUpdateMode() - this.api.addEventListener('firstDataRendered', (e: any) => this.fitColumns()) + this.api.addEventListener('firstDataRendered', (e: any) => { + this.fitColumns() + if (this.defaultColumnFilters) { + let filters = JSON.parse(this.defaultColumnFilters) + this.api.setFilterModel(filters) + } + }) this.api.setRowData(this.state.rowData) @@ -270,7 +278,8 @@ class AgGrid extends StreamlitComponentBase { let returnValue = { originalDtypes: this.frameDtypes, rowData: returnData, - selectedRows: this.api.getSelectedRows() + selectedRows: this.api.getSelectedRows(), + columnFilters: JSON.stringify(this.api.getFilterModel()) } Streamlit.setComponentValue(returnValue)