From 22d190a19c07aa0b8fab49cc5bfa36e33f4bacce Mon Sep 17 00:00:00 2001 From: Carl Buchholz <32228189+aGuyLearning@users.noreply.github.com> Date: Wed, 18 Dec 2024 14:47:09 +0100 Subject: [PATCH] weibull_aft arguments update --- ehrapy/tools/_sa.py | 102 +++++++++++++++++++++++++++++++++----------- 1 file changed, 77 insertions(+), 25 deletions(-) diff --git a/ehrapy/tools/_sa.py b/ehrapy/tools/_sa.py index 241e5dee..13b74a8f 100644 --- a/ehrapy/tools/_sa.py +++ b/ehrapy/tools/_sa.py @@ -358,6 +358,26 @@ def _regression_model_data_frame_preparation(adata: AnnData, duration_col: str, return df +def _regression_model_populate_adata(adata: AnnData, model_summary: pd.DataFrame, key_added_prefix: str = None): + if key_added_prefix is None: + key_added_prefix = "" + else: + key_added_prefix = key_added_prefix + "_" + + full_results = pd.DataFrame(index=adata.var.index) + + # Populate with the fitted model's summary data + for key in model_summary.columns: + full_results[key_added_prefix + key] = model_summary[key] + + # Add a boolean column indicating rows populated by this function + full_results[key_added_prefix + "cox_ph_populated"] = full_results.notna().any(axis=1) + + # Assign results back to adata.var + for col in full_results.columns: + adata.var[col] = full_results[col] + + def cox_ph( adata: AnnData, duration_col: str, @@ -397,6 +417,7 @@ def cox_ph( event_col: The name of the column in anndata that contains the subjects’ death observation. If left as None, assume all individuals are uncensored. inplace: Whether to modify the AnnData object in place. + key_added_prefix: Prefix to add to the column names in the AnnData object. An underscore will be added between the prefix and the column name. alpha: The alpha value in the confidence intervals. label: A string to name the column of the estimate. baseline_estimation_method: The method used to estimate the baseline hazard. Options are 'breslow', 'spline', and 'piecewise'. 
@@ -456,26 +477,7 @@ def cox_ph( # Add the results to the AnnData object if inplace: - if key_added_prefix is None: - key_added_prefix = "" - else: - key_added_prefix = key_added_prefix + "_" - - cox_ph_summary = cox_ph.summary - print(cox_ph_summary) - - full_results = pd.DataFrame(index=adata.var.index) - - # Populate with CoxPH summary data - for key in cox_ph_summary.columns: - full_results[key_added_prefix + key] = cox_ph_summary[key] - - # Add a boolean column indicating rows populated by this function - full_results[key_added_prefix + "cox_ph_populated"] = full_results.notna().any(axis=1) - - # Assign results back to adata.var - for col in full_results.columns: - adata.var[col] = full_results[col] + _regression_model_populate_adata(adata, cox_ph.summary, key_added_prefix) return cox_ph @@ -492,7 +494,7 @@ def weibull_aft( l1_ratio: float = 0.0, model_ancillary: bool = True, event_col: str | None = None, - ancillary: bool | pd.DataFrame | None = None, + ancillary: bool | pd.DataFrame | str | None = None, show_progress: bool = False, weights_col: str | None = None, robust: bool = False, @@ -512,9 +514,29 @@ def weibull_aft( Args: adata: AnnData object with necessary columns `duration_col` and `event_col`. duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: Name of the column in anndata that contains the subjects’ death observation. + inplace: Whether to modify the AnnData object in place. + key_added_prefix: Prefix to add to the column names in the AnnData object. An underscore will be added between the prefix and the column name. + alpha: The alpha value in the confidence intervals. + fit_intercept: Whether to fit an intercept term in the model. + penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates. + l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. 
See penalizer above. + model_ancillary: set the model instance to always model the ancillary parameter with the supplied DataFrame. This is useful for grid-search optimization. + event_col: Name of the column in anndata that contains the subjects’ death observation. 1 if observed, 0 else (censored). If left as None, assume all individuals are uncensored. + ancillary: Choose to model the ancillary parameters. + If None or False, explicitly do not fit the ancillary parameters using any covariates. + If True, model the ancillary parameters with the same covariates as ``df``. + If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``. + If str, should be a formula. + show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing. + weights_col: The name of the column in DataFrame that contains the weights for each subject. + robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there is a high number of ties, results may significantly differ. + initial_point: set the starting point for the iterative solver. entry_col: Column denoting when a subject entered the study, i.e. left-truncation. + formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/ + If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.) + fit_options: Additional keyword arguments to pass into the estimator. + Returns: Fitted WeibullAFTFitter. 
@@ -522,12 +544,41 @@ def weibull_aft( Examples: >>> import ehrapy as ep >>> adata = ep.dt.mimic_2(encoded=False) - >>> # Flip 'censor_fl' because 0 = death and 1 = censored >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) - >>> aft = ep.tl.weibull_aft(adata, "mort_day_censored", "censor_flg") + >>> adata = adata[:, ["mort_day_censored", "censor_flg"]] + >>> aft = ep.tl.weibull_aft(adata, duration_col="mort_day_censored", event_col="censor_flg") + >>> aft.print_summary() """ - return _regression_model(WeibullAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False) + df = _regression_model_data_frame_preparation(adata, duration_col, accept_zero_duration=False) + + weibull_aft = WeibullAFTFitter( + alpha=alpha, + fit_intercept=fit_intercept, + penalizer=penalizer, + l1_ratio=l1_ratio, + model_ancillary=model_ancillary, + ) + + weibull_aft.fit( + df, + duration_col=duration_col, + event_col=event_col, + entry_col=entry_col, + ancillary=ancillary, + show_progress=show_progress, + weights_col=weights_col, + robust=robust, + initial_point=initial_point, + formula=formula, + fit_options=fit_options, + ) + + # Add the results to the AnnData object + if inplace: + _regression_model_populate_adata(adata, weibull_aft.summary, key_added_prefix) + + return weibull_aft def log_logistic_aft( @@ -566,6 +617,7 @@ def log_logistic_aft( >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) >>> llf = ep.tl.log_logistic_aft(adata, "mort_day_censored", "censor_flg") """ + return _regression_model( LogLogisticAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False )