Skip to content

Commit

Permalink
weibull_aft arguments update
Browse files Browse the repository at this point in the history
  • Loading branch information
aGuyLearning committed Dec 18, 2024
1 parent 35dbacf commit 22d190a
Showing 1 changed file with 77 additions and 25 deletions.
102 changes: 77 additions & 25 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,26 @@ def _regression_model_data_frame_preparation(adata: AnnData, duration_col: str,
return df


def _regression_model_populate_adata(adata: AnnData, model_summary: pd.DataFrame, key_added_prefix: str = None):
if key_added_prefix is None:
key_added_prefix = ""
else:
key_added_prefix = key_added_prefix + "_"

full_results = pd.DataFrame(index=adata.var.index)

# Populate with CoxPH summary data
for key in model_summary.columns:
full_results[key_added_prefix + key] = model_summary[key]

# Add a boolean column indicating rows populated by this function
full_results[key_added_prefix + "cox_ph_populated"] = full_results.notna().any(axis=1)

# Assign results back to adata.var
for col in full_results.columns:
adata.var[col] = full_results[col]


def cox_ph(
adata: AnnData,
duration_col: str,
Expand Down Expand Up @@ -397,6 +417,7 @@ def cox_ph(
event_col: The name of the column in anndata that contains the subjects’ death observation.
If left as None, assume all individuals are uncensored.
inplace: Whether to modify the AnnData object in place.
key_added_prefix: Prefix to add to the column names in the AnnData object. An underscore will be added between the prefix and the column name.
alpha: The alpha value in the confidence intervals.
label: A string to name the column of the estimate.
baseline_estimation_method: The method used to estimate the baseline hazard. Options are 'breslow', 'spline', and 'piecewise'.
Expand Down Expand Up @@ -456,26 +477,7 @@ def cox_ph(

# Add the results to the AnnData object
if inplace:
if key_added_prefix is None:
key_added_prefix = ""
else:
key_added_prefix = key_added_prefix + "_"

cox_ph_summary = cox_ph.summary
print(cox_ph_summary)

full_results = pd.DataFrame(index=adata.var.index)

# Populate with CoxPH summary data
for key in cox_ph_summary.columns:
full_results[key_added_prefix + key] = cox_ph_summary[key]

# Add a boolean column indicating rows populated by this function
full_results[key_added_prefix + "cox_ph_populated"] = full_results.notna().any(axis=1)

# Assign results back to adata.var
for col in full_results.columns:
adata.var[col] = full_results[col]
_regression_model_populate_adata(adata, cox_ph.summary, key_added_prefix)

return cox_ph

Expand All @@ -492,7 +494,7 @@ def weibull_aft(
l1_ratio: float = 0.0,
model_ancillary: bool = True,
event_col: str | None = None,
ancillary: bool | pd.DataFrame | None = None,
ancillary: bool | pd.DataFrame | str | None = None,
show_progress: bool = False,
weights_col: str | None = None,
robust: bool = False,
Expand All @@ -512,22 +514,71 @@ def weibull_aft(
Args:
adata: AnnData object with necessary columns `duration_col` and `event_col`.
duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes.
event_col: Name of the column in anndata that contains the subjects’ death observation.
inplace: Whether to modify the AnnData object in place.
key_added_prefix: Prefix to add to the column names in the AnnData object. An underscore will be added between the prefix and the column name.
alpha: The alpha value in the confidence intervals.
fit_intercept: Whether to fit an intercept term in the model.
penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates.
l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above.
model_ancillary: Set the model instance to always model the ancillary parameter with the supplied DataFrame. This is useful for grid-search optimization.
event_col: Name of the column in anndata that contains the subjects’ death observation. 1 if observed, 0 else (censored).
If left as None, assume all individuals are uncensored.
ancillary: Choose to model the ancillary parameters.
If None or False, explicitly do not fit the ancillary parameters using any covariates.
If True, model the ancillary parameters with the same covariates as ``df``.
If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``.
If str, should be a formula
show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing.
weights_col: The name of the column in DataFrame that contains the weights for each subject.
robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ.
initial_point: set the starting point for the iterative solver.
entry_col: Column denoting when a subject entered the study, i.e. left-truncation.
formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/
If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.)
fit_options: Additional keyword arguments to pass into the estimator.
Returns:
Fitted WeibullAFTFitter.
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Flip 'censor_flg' because 0 = death and 1 = censored
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> aft = ep.tl.weibull_aft(adata, "mort_day_censored", "censor_flg")
>>> adata = adata[:, ["mort_day_censored", "censor_flg"]]
>>> aft = ep.tl.weibull_aft(adata, duration_col="mort_day_censored", event_col="censor_flg")
>>> aft.print_summary()
"""

return _regression_model(WeibullAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False)
df = _regression_model_data_frame_preparation(adata, duration_col, accept_zero_duration=False)

weibull_aft = WeibullAFTFitter(
alpha=alpha,
fit_intercept=fit_intercept,
penalizer=penalizer,
l1_ratio=l1_ratio,
model_ancillary=model_ancillary,
)

weibull_aft.fit(
df,
duration_col=duration_col,
event_col=event_col,
entry_col=entry_col,
ancillary=ancillary,
show_progress=show_progress,
weights_col=weights_col,
robust=robust,
initial_point=initial_point,
formula=formula,
fit_options=fit_options,
)

# Add the results to the AnnData object
if inplace:
_regression_model_populate_adata(adata, weibull_aft.summary, key_added_prefix)

return weibull_aft


def log_logistic_aft(
Expand Down Expand Up @@ -566,6 +617,7 @@ def log_logistic_aft(
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> llf = ep.tl.log_logistic_aft(adata, "mort_day_censored", "censor_flg")
"""

return _regression_model(
LogLogisticAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False
)
Expand Down

0 comments on commit 22d190a

Please sign in to comment.