diff --git a/synthpop/synthesizer.py b/synthpop/synthesizer.py index 9123c1f..67a570e 100644 --- a/synthpop/synthesizer.py +++ b/synthpop/synthesizer.py @@ -5,6 +5,8 @@ import numpy as np import pandas as pd from scipy.stats import chisquare +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor from . import categorizer as cat from . import draw @@ -75,6 +77,7 @@ def synthesize(h_marg, p_marg, h_jd, p_jd, h_pums, p_pums, geography, ignore_max ignore_max_iters) logger.info("Time to run ipu: %.3fs" % (time.time()-t1)) + logger.info("Time to run ipu: %.3fs" % (time.time() - t1)) logger.debug("IPU weights:") logger.debug(best_weights.describe()) logger.debug("Fit quality:") @@ -91,6 +94,100 @@ def synthesize(h_marg, p_marg, h_jd, p_jd, h_pums, p_pums, geography, ignore_max num_households, h_pums, p_pums, household_freq, h_constraint, p_constraint, best_weights, hh_index_start=hh_index_start) +def geog_preprocessing(geog_id, recipe, marginal_zero_sub, jd_zero_sub, + hh_index_start): + h_marg = recipe.get_household_marginal_for_geography(geog_id) + logger.debug("Household marginal") + logger.debug(h_marg) + + p_marg = recipe.get_person_marginal_for_geography(geog_id) + logger.debug("Person marginal") + logger.debug(p_marg) + + h_pums, h_jd = recipe.\ + get_household_joint_dist_for_geography(geog_id) + logger.debug("Household joint distribution") + logger.debug(h_jd) + + p_pums, p_jd = recipe.get_person_joint_dist_for_geography(geog_id) + logger.debug("Person joint distribution") + logger.debug(p_jd) + + return h_marg, p_marg, h_jd, p_jd, h_pums, p_pums, marginal_zero_sub,\ + jd_zero_sub, hh_index_start + +def synthesize_all_in_parallel( + recipe, num_geogs=None, indexes=None, marginal_zero_sub=.01, + jd_zero_sub=.001, max_workers=5, hh_index_start=0): + """ + Returns + ------- + households, people : pandas.DataFrame + fit_quality : dict of FitQuality + Keys are geographic IDs, values are namedtuples with attributes + ``.household_chisq``, ``household_p``, ``people_chisq``, + and ``people_p``. + """ + with ProcessPoolExecutor(max_workers) as ex: + if indexes is None: + indexes = recipe.get_available_geography_ids() + + hh_list = [] + people_list = [] + cnt = 0 + fit_quality = {} + geog_synth_args = [] + finished_args = [] + geog_ids = [] + futures = [] + + print('Submitting function args for parallel processing:') + for i, geog_id in enumerate(indexes): + geog_synth_args.append(ex.submit( + geog_preprocessing, geog_id, recipe, marginal_zero_sub, + jd_zero_sub, hh_index_start)) + geog_ids.append(geog_id) + cnt += 1 + if num_geogs is not None and cnt >= num_geogs: + break + + + with ProcessPoolExecutor(max_workers=5) as ex: + futures = [ + ex.submit(synthesize, *geog_args.result()) for geog_args in geog_synth_args] + + print('Processing results:') + for i, future in tqdm(enumerate(futures), total=len(futures)): + geog_id = geog_ids[i] + print ('Processing results for: ', geog_id) + try: + households, people, people_chisq, people_p = future.result() + except: + raise ValueError('The synthesis failed for geog_id: {}'.format(geog_id)) + else: + # Append location identifiers to the synthesized households + for geog_cat in geog_id.keys(): + households[geog_cat] = geog_id[geog_cat] + + # update the household_ids since we can't do it in the call to + # synthesize when we execute in parallel + households.index += hh_index_start + people.hh_id += hh_index_start + + hh_list.append(households) + people_list.append(people) + key = BlockGroupID( + geog_id['state'], geog_id['county'], geog_id['tract'], + geog_id['block group']) + fit_quality[key] = FitQuality(people_chisq, people_p) + + if len(households) > 0: + hh_index_start = households.index.values[-1] + 1 + + all_households = pd.concat(hh_list) + all_persons = pd.concat(people_list, ignore_index=True) + + return (all_households, all_persons, fit_quality) def synthesize_all(recipe, num_geogs=None, indexes=None, ignore_max_iters=False, marginal_zero_sub=.01, jd_zero_sub=.001): @@ -116,8 +213,7 @@ def synthesize_all(recipe, num_geogs=None, indexes=None, ignore_max_iters=False, fit_quality = {} hh_index_start = 0 - # TODO will parallelization work here? - for geog_id in indexes: + for geog_id in tqdm(indexes, total=num_geogs): print("Synthesizing geog id:\n", geog_id) h_marg = recipe.get_household_marginal_for_geography(geog_id)