-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
848 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Zivich PN (August 25, 2022). "Why I use Python (and Why You Should Too)" [presentation]. | ||
43rd Annual Conference of the International Society for Clinical Biostatistics, Newcastle upon Tyne, UK |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
##################################################################################################### | ||
# | ||
# Python code for example in | ||
# Re: Using numerical methods to design simulations: revisiting the balancing intercept | ||
# Paul Zivich and Rachael Ross | ||
# | ||
# Objective: solve for intercepts in data generating models to achieve desired marginal distribution | ||
# | ||
##################################################################################################### | ||
|
||
import warnings | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.optimize import root, newton | ||
|
||
np.random.seed(777743) | ||
|
||
# Setup the baseline data | ||
n = 10000000 | ||
d = pd.DataFrame() | ||
d['X'] = np.random.normal(size=n) | ||
d['C'] = 1 | ||
print("E[X]: ", np.mean(d['X'])) | ||
|
||
|
||
######################################## | ||
# Solving for balancing intercept for A | ||
|
||
# Pr(A | X) = logit(\alpha_0 + alpha_coefs[0]*X) | ||
desired_margin_a = 0.45 | ||
alpha_coefs = [0.25] | ||
W = np.asarray(d[['C', 'X']]) | ||
|
||
|
||
def generate_pr_a(intercepts): | ||
"""Function to calculate the probability of A given an intercept | ||
""" | ||
alpha = np.asarray([intercepts[0]] + alpha_coefs) # Takes intercept and puts together with specified coefs | ||
|
||
# Calculating the probability of A given the coefficients | ||
logit_pr_a = np.dot(W, alpha) # log-odds of A | ||
prob_a = 1 / (1 + np.exp(-logit_pr_a)) # converting to probability of A | ||
return prob_a # Function returns array / vector of probabilities | ||
|
||
|
||
def objective_function_a(intercepts): | ||
"""Objective function to use with a root-finding algorithm to solve for the intercept that provides the desired | ||
marginal probability of A | ||
""" | ||
prob_a = generate_pr_a(intercepts=intercepts) # Calculate probability of A for given intercept | ||
marginal_pr_a = np.mean(prob_a) # Calculate the marginal probability of A | ||
difference_from_desired = marginal_pr_a - desired_margin_a # Calculate difference between current and desired marg | ||
return difference_from_desired # Return the current difference for the intercept | ||
|
||
|
||
# Root-finding procedure for Pr(A) | ||
root_a = newton(objective_function_a, | ||
x0=np.asarray([0.]), | ||
tol=1e-12, maxiter=1000) | ||
|
||
# Examining results | ||
print("alpha_0: ", root_a) | ||
print("Pr(A=1): ", np.mean(generate_pr_a(root_a))) | ||
|
||
######################################## | ||
# Solving for balancing intercept for M | ||
|
||
# where model is Pr(M=1 | X) / Pr(M=0 | X) = ln(\beta_10 + beta_coefs[0][0]*A + beta_coefs[0][1]*X) | ||
# Pr(M=2 | X) / Pr(M=0 | X) = ln(\beta_20 + beta_coefs[1][0]*A + beta_coefs[1][1]*X) | ||
desired_margin_m = np.array([0.5, 0.35, 0.15]) # Desired margins | ||
beta_coefs = [[1.2, -0.15], # Coefficients for M=1 vs. M=0 besides intercept | ||
[0.65, -0.07]] # Coefficients for M=2 vs. M=0 besides intercept | ||
d['A'] = np.random.binomial(n=1, # Generating values of A from model | ||
p=generate_pr_a(root_a), # Using previously numer. approx. of intercept | ||
size=d.shape[0]) # size is number of obs | ||
V = np.asarray(d[['C', 'A', 'X']]) # Covariates to include in model | ||
|
||
|
||
def generate_pr_m(intercepts): | ||
"""Function to calculate the probability of M for each possible value of M given intercepts | ||
""" | ||
beta_10 = np.asarray([intercepts[0]] + beta_coefs[0]) # Takes intercept and puts together with M=1 specified coefs | ||
beta_20 = np.asarray([intercepts[1]] + beta_coefs[1]) # Takes intercept and puts together with M=2 specified coefs | ||
|
||
# Calculating denominator for probability model | ||
denom = 1 + np.exp(np.dot(V, beta_10)) + np.exp(np.dot(V, beta_20)) | ||
|
||
# Calculating probability of M for each category via multinomial logit model | ||
prob_m = np.array([1 / denom, # Probability of M=0 | ||
np.exp(np.dot(V, beta_10)) / denom, # Probability of M=1 | ||
np.exp(np.dot(V, beta_20)) / denom], # Probability of M=2 | ||
) | ||
|
||
# Extra step to check if probability sums to 1 for each individual | ||
if not np.all(np.sum(prob_m, axis=0).round(7) == 1.): # (rounding to avoid approximation errors) | ||
warnings.warn("Some Pr didn't sum to 1... :(", # Warn user if fails to sum to 1 for any individual | ||
UserWarning) | ||
|
||
return prob_m # Function returns 2D array / vector of probabilities | ||
|
||
|
||
def objective_function_m(intercepts): | ||
"""Objective function to use with a root-finding algorithm to solve for the intercept that provides the desired | ||
marginal probabilities of M | ||
""" | ||
prob_m = generate_pr_m(intercepts=intercepts) # Calculate probability of A for given intercept | ||
marginal_pr_m = np.mean(prob_m, axis=1) # Calculate the marginal probability of M across types | ||
difference_from_desired = marginal_pr_m - desired_margin_m # Calculate difference between current and desired marg | ||
return difference_from_desired[1:] # Return the current difference for all BUT M=0 | ||
|
||
|
||
opt_m = root(objective_function_m, # The objective function | ||
x0=np.asarray([0., 0.]), # Initial starting values for procedure (need 2 intercepts here!) | ||
method='lm', tol=1e-12) # Arguments for root-finding algorithm | ||
|
||
# Examining results | ||
print("beta_0: ", opt_m.x) | ||
print("Pr(M): ", np.mean(generate_pr_m(opt_m.x), axis=1)) | ||
|
||
######################################## | ||
# Solving for balancing intercept for Y | ||
|
||
# where the model is Y = \gamma_0 + gamma_coefs[0]*A + gamma_coefs[1]*(M=1) + gamma_coefs[2]*(M=2) | ||
# + gamma_coefs[3]*X + Normal(0, 3) | ||
desired_margin_y = 10. # Desired margin | ||
gamma_coefs = [-1.55, 0.25, 0.45, 0.25] # Coefficients for Y model besides intercept | ||
|
||
|
||
def random_multinomial(a, p): | ||
"""Quick function to generate random values from input multinomial probabilities | ||
""" | ||
s = p.cumsum(axis=0) | ||
r = np.random.rand(p.shape[1]) | ||
k = (s < r).sum(axis=0) | ||
return np.asarray(a)[k] | ||
|
||
|
||
d['M'] = random_multinomial(a=[0, 1, 2], # Generating values of M from model | ||
p=generate_pr_m(opt_m.x)) # Using previously numer. approx. of intercept | ||
d['M1'] = np.where(d['M'] == 1, 1, 0) # Creating indicator variables (for ease) | ||
d['M2'] = np.where(d['M'] == 2, 1, 0) # Creating indicator variables (for ease) | ||
Z = np.asarray(d[['C', 'A', 'M1', 'M2', 'X']]) # Covariates to include in model | ||
error = np.random.normal(scale=3, size=d.shape[0]) # How error terms are simulated | ||
|
||
|
||
def generate_y(intercepts): | ||
"""Function to calculate the values of Y given an intercept | ||
""" | ||
gamma = np.asarray([intercepts[0]] + gamma_coefs) # Takes intercept and puts together with specified coefs | ||
|
||
# Calculating Y values given the coefficients | ||
y = np.dot(Z, gamma) # notice that we ignore the error term here (since safely ignorable for approx. intercepts) | ||
return y # Function returns array / vector of Y values | ||
|
||
|
||
def objective_function_y(intercepts): | ||
"""Objective function to use with a root-finding algorithm to solve for the intercept that provides the desired | ||
marginal probability of A | ||
""" | ||
val_y = generate_y(intercepts=intercepts) # Calculate probability of A for given intercept | ||
marginal_mu_y = np.mean(val_y) # Calculate the marginal mean of Y | ||
difference_from_desired = marginal_mu_y - desired_margin_y # Calculate difference between current and desired marg | ||
return difference_from_desired # Return the current difference for the intercept | ||
|
||
|
||
# Root-finding procedure for Pr(A) | ||
root_y = newton(objective_function_y, # The objective function | ||
x0=np.asarray([0.]), # Initial starting values for procedure | ||
tol=1e-12, maxiter=1000) # Arguments for root-finding algorithm | ||
|
||
# Examining results | ||
print("gamma_0: ", root_y) | ||
print("E[Y]: ", np.mean(generate_y(root_y) + error)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from matplotlib.patches import Rectangle | ||
|
||
# IMAGE: flow diagram for GAN data simulation | ||
fig, ax = plt.subplots() | ||
|
||
# Random noise section | ||
for x in np.arange(0.01, 0.08, 0.01): | ||
for y in np.arange(0.7, 0.81, 0.01): | ||
shade = np.random.uniform(0, 1, 1) | ||
ax.fill_between([x, x+0.01], [y, y], [y+0.01, y+0.01], | ||
color='k', alpha=shade[0]) | ||
|
||
# Generator | ||
ax.arrow(0.1, 0.75, 0.05, 0, head_width=0.01, color='k') | ||
ax.fill_between([0.18, 0.42], [0.8, 0.8], [0.7, 0.7], color='aqua', alpha=0.2) | ||
ax.text(0.20, 0.73, "Generator", size=16) | ||
|
||
# Output generator | ||
ax.arrow(0.43, 0.75, 0.05, 0, head_width=0.01, color='k') | ||
ax.text(0.52, 0.66, r"$X^*_1$" + "\n" + r"$X^*_2$" + "\n" + r"$X^*_3$", size=14) | ||
|
||
# Real Data | ||
ax.text(0.52, 0.18, r"$X_1$" + "\n" + r"$X_2$" + "\n" + r"$X_3$", size=14) | ||
|
||
# Discriminator | ||
ax.arrow(0.58, 0.72, 0.10, -0.15, head_width=0.01, color='k') | ||
ax.arrow(0.58, 0.27, 0.10, 0.15, head_width=0.01, color='k') | ||
ax.fill_between([0.55, 0.83], [0.55, 0.55], [0.45, 0.45], color='orange', alpha=0.2) | ||
ax.text(0.57, 0.48, "Discriminator", size=16) | ||
ax.arrow(0.85, 0.5, 0.05, 0, head_width=0.01, color='k') | ||
ax.text(0.93, 0.485, "T/F", size=14) | ||
|
||
ax.set_ylim([-0., 1.]) | ||
ax.set_xlim([-0., 1.]) | ||
plt.axis('off') | ||
plt.tight_layout() | ||
plt.savefig("images/gan_flow.png", dpi=600, format='png') | ||
plt.close() | ||
|
||
# IMAGE: flow diagram for RNN text generation | ||
fig, ax = plt.subplots() | ||
|
||
# Step 1: PubMed Query | ||
ax.text(0.02, 0.95, "1: Query PubMed") | ||
rectangle = Rectangle((0., 0.62), 0.95, 0.38, alpha=0.2, color='aqua') | ||
ax.add_patch(rectangle) | ||
ax.text(0.11, 0.86, "1a: Conduct search & extract PubMed IDs") | ||
rectangle = Rectangle((0.1, 0.92), 0.8, -0.08, alpha=0.2, color='blue') | ||
ax.add_patch(rectangle) | ||
ax.text(0.11, 0.76, "1b: Select random sample") | ||
rectangle = Rectangle((0.1, 0.82), 0.8, -0.08, alpha=0.2, color='blue') | ||
ax.add_patch(rectangle) | ||
ax.text(0.11, 0.66, "1c: Pull meta-data from PubMed") | ||
rectangle = Rectangle((0.1, 0.72), 0.8, -0.08, alpha=0.2, color='blue') | ||
ax.add_patch(rectangle) | ||
|
||
# Step 2: text processing | ||
ax.text(0.02, 0.55, "2: Text processing") | ||
rectangle = Rectangle((0., 0.32), 0.95, 0.28, alpha=0.2, color='orange') | ||
ax.add_patch(rectangle) | ||
ax.text(0.11, 0.47, "2a: Extract abstracts") | ||
rectangle = Rectangle((0.1, 0.53), 0.8, -0.08, alpha=0.2, color='red') | ||
ax.add_patch(rectangle) | ||
ax.text(0.11, 0.37, "2b: Format text to training data") | ||
rectangle = Rectangle((0.1, 0.43), 0.8, -0.08, alpha=0.2, color='red') | ||
ax.add_patch(rectangle) | ||
|
||
# Step 3: train network | ||
ax.text(0.02, 0.24, "3: RNN") | ||
rectangle = Rectangle((0., -0.4), 0.95, 0.7, alpha=0.2, color='green', zorder=1) | ||
ax.text(0.18, -0.1, "Input: \nsent") | ||
ax.text(0.775, -0.1, "Output: \nsentence") | ||
|
||
ax.add_patch(rectangle) | ||
ax.text(0.275, 0.14, "sent") | ||
ax.text(0.292, -0.33, "e") | ||
ax.arrow(0.3, 0.12, 0, -0.05, head_width=0.01, color='k') | ||
ax.arrow(0.3, -0.21, 0, -0.05, head_width=0.01, color='k') | ||
ax.scatter([0.30, 0.30, 0.30, 0.30], [-0.15, -0.10, -0.05, 0.], | ||
marker='o', c='white', edgecolors='k', zorder=2) | ||
rectangle = Rectangle((0.29, -0.2), 0.02, 0.25, alpha=1, facecolor='none', edgecolor='k') | ||
ax.add_patch(rectangle) | ||
ax.arrow(0.315, -0.28, 0.09, 0.4, head_width=0.01, color='k') | ||
|
||
ax.text(0.415, 0.14, "ente") | ||
ax.text(0.431, -0.33, "n") | ||
ax.arrow(0.44, 0.12, 0, -0.05, head_width=0.01, color='k') | ||
ax.arrow(0.44, -0.21, 0, -0.05, head_width=0.01, color='k') | ||
ax.scatter([0.44, 0.44, 0.44, 0.44], [-0.15, -0.10, -0.05, 0.], | ||
marker='o', c='white', edgecolors='k', zorder=2) | ||
rectangle = Rectangle((0.43, -0.2), 0.02, 0.25, alpha=1, facecolor='none', edgecolor='k') | ||
ax.add_patch(rectangle) | ||
ax.arrow(0.45, -0.28, 0.09, 0.4, head_width=0.01, color='k') | ||
|
||
ax.text(0.55, 0.14, "nten") | ||
ax.text(0.563, -0.33, "c") | ||
ax.arrow(0.57, 0.12, 0, -0.05, head_width=0.01, color='k') | ||
ax.arrow(0.57, -0.21, 0, -0.05, head_width=0.01, color='k') | ||
ax.scatter([0.57, 0.57, 0.57, 0.57], [-0.15, -0.10, -0.05, 0.], | ||
marker='o', c='white', edgecolors='k', zorder=2) | ||
rectangle = Rectangle((0.5595, -0.2), 0.02, 0.25, alpha=1, facecolor='none', edgecolor='k') | ||
ax.add_patch(rectangle) | ||
ax.arrow(0.58, -0.28, 0.09, 0.4, head_width=0.01, color='k') | ||
|
||
ax.text(0.675, 0.14, "tenc") | ||
ax.text(0.687, -0.33, "e") | ||
ax.arrow(0.695, 0.12, 0, -0.05, head_width=0.01, color='k') | ||
ax.arrow(0.695, -0.21, 0, -0.05, head_width=0.01, color='k') | ||
ax.scatter([0.695, 0.695, 0.695, 0.695], [-0.15, -0.10, -0.05, 0.], | ||
marker='o', c='white', edgecolors='k', zorder=2) | ||
rectangle = Rectangle((0.6855, -0.2), 0.02, 0.25, alpha=1, facecolor='none', edgecolor='k') | ||
ax.add_patch(rectangle) | ||
|
||
ax.set_ylim([-0.6, 1.1]) | ||
ax.set_xlim([-0., 0.96]) | ||
plt.axis('off') | ||
plt.tight_layout() | ||
plt.savefig("images/rnn_flow.png", dpi=600, format='png') | ||
plt.close() |
Oops, something went wrong.