diff --git a/tests/pipeline/test_engine.py b/tests/pipeline/test_engine.py index 401bc30b17..e11d61d977 100644 --- a/tests/pipeline/test_engine.py +++ b/tests/pipeline/test_engine.py @@ -910,11 +910,12 @@ class SyntheticBcolzTestCase(zf.WithAdjustmentReader, first_asset_start = Timestamp('2015-04-01', tz='UTC') START_DATE = Timestamp('2015-01-01', tz='utc') END_DATE = Timestamp('2015-08-01', tz='utc') + ASSET_FINDER_EQUITY_SIDS = list(range(6)) @classmethod def make_equity_info(cls): cls.equity_info = ret = make_rotating_equity_info( - num_assets=6, + sids=cls.ASSET_FINDER_EQUITY_SIDS, first_start=cls.first_asset_start, frequency=cls.trading_calendar.day, periods_between_starts=4, diff --git a/tests/pipeline/test_international_markets.py b/tests/pipeline/test_international_markets.py index 2b0d63f34e..ed00afc870 100644 --- a/tests/pipeline/test_international_markets.py +++ b/tests/pipeline/test_international_markets.py @@ -3,9 +3,10 @@ import numpy as np import pandas as pd -from trading_calendars import get_calendar - -from zipline.assets.synthetic import make_rotating_equity_info +from zipline.assets.synthetic import ( + make_rotating_equity_info, + make_multi_exchange_equity_info, +) from zipline.data.in_memory_daily_bars import InMemoryDailyBarReader from zipline.pipeline.domain import ( CA_EQUITIES, @@ -18,7 +19,10 @@ from zipline.pipeline.loaders.equity_pricing_loader import EquityPricingLoader from zipline.pipeline.loaders.synthetic import NullAdjustmentReader from zipline.testing.predicates import assert_equal -from zipline.testing.core import parameter_space, random_tick_prices +from zipline.testing.core import ( + parameter_space, + random_tick_prices, +) import zipline.testing.fixtures as zf @@ -146,29 +150,26 @@ class InternationalEquityTestCase(WithInternationalPricingPipelineEngine, @classmethod def make_equity_info(cls): - out = pd.concat( - [ - # 15 assets on each exchange. Each asset lives for 5 days. - # A new asset starts each day. - make_rotating_equity_info( - num_assets=20, - first_start=cls.START_DATE, - frequency=get_calendar(exchange).day, - periods_between_starts=1, - # NOTE: The asset_lifetime parameter name is a bit - # misleading. It determines the number of trading - # days between each asset's start_date and end_date, - # so assets created with this method actual "live" - # for (asset_lifetime + 1) days. But, since pipeline - # doesn't show you an asset the day it IPOs, this - # number matches the number of days that each asset - # should appear in a pipeline output. - asset_lifetime=5, - exchange=exchange, - ) - for exchange in cls.EXCHANGE_INFO.exchange - ], - ignore_index=True, + # - 20 assets on each exchange. + # - Each asset lives for 5 days. + # - A new asset starts each day. + out = make_multi_exchange_equity_info( + factory=make_rotating_equity_info, + exchange_sids={ + 'XNYS': range(20), + 'XTSE': range(20, 40), + 'XLON': range(40, 60), + }, + first_start=cls.START_DATE, + periods_between_starts=1, + # NOTE: The asset_lifetime parameter name is a bit misleading. It + # determines the number of trading days between each asset's + # start_date and end_date, so assets created with this method + # actual "live" for (asset_lifetime + 1) days. But, since + # pipeline doesn't show you an asset the day it IPOs, this + # number matches the number of days that each asset should + # appear in a pipeline output. + asset_lifetime=5, ) assert_equal(out.end_date.max(), cls.END_DATE) return out @@ -211,7 +212,7 @@ def test_generic_pipeline_with_explicit_domain(self, domain): expected_dates = sessions[-17:-9] for col in pipe.columns: - # result_date should look like this: + # result_data should look like this: # # E F G H I J K L M N O P # noqa # 24.17 25.17 26.17 27.17 28.17 NaN NaN NaN NaN NaN NaN NaN # noqa diff --git a/tests/test_assets.py b/tests/test_assets.py index a7d454f4c5..af1c41909a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -1056,19 +1056,22 @@ def test_compute_lifetimes(self): equities = pd.concat( [ make_rotating_equity_info( - num_assets=assets_per_exchange, + sids=range( + i * assets_per_exchange, + (i + 1) * assets_per_exchange, + ), first_start=first_start, frequency=trading_day, periods_between_starts=3, asset_lifetime=5, exchange=exchange, ) - for exchange in ( + for i, exchange in enumerate(( 'US_EXCHANGE_1', 'US_EXCHANGE_2', 'CA_EXCHANGE', 'JP_EXCHANGE', - ) + )) ], ignore_index=True, ) @@ -1076,7 +1079,7 @@ def test_compute_lifetimes(self): equities['symbol'] = list(string.ascii_uppercase[:len(equities)]) # shuffle up the sids so they are not contiguous per exchange - sids = np.arange(len(equities)) + sids = equities.index.values[:] np.random.RandomState(1337).shuffle(sids) equities.index = sids permute_sid = dict(zip(sids, range(len(sids)))).__getitem__ diff --git a/zipline/assets/synthetic.py b/zipline/assets/synthetic.py index 390387050f..1e9caf7ca4 100644 --- a/zipline/assets/synthetic.py +++ b/zipline/assets/synthetic.py @@ -4,11 +4,13 @@ import pandas as pd from pandas.tseries.offsets import MonthBegin from six import iteritems +from toolz import merge +from trading_calendars import get_calendar from .futures import CMES_CODE_TO_MONTH -def make_rotating_equity_info(num_assets, +def make_rotating_equity_info(sids, first_start, frequency, periods_between_starts, @@ -38,6 +40,7 @@ def make_rotating_equity_info(num_assets, info : pd.DataFrame DataFrame representing newly-created assets. """ + num_assets = len(sids) return pd.DataFrame( { 'symbol': [chr(ord('A') + i) for i in range(num_assets)], @@ -55,7 +58,7 @@ def make_rotating_equity_info(num_assets, ), 'exchange': exchange, }, - index=range(num_assets), + index=sids, ) @@ -117,12 +120,13 @@ def make_simple_equity_info(sids, ) -def make_jagged_equity_info(num_assets, +def make_jagged_equity_info(sids, start_date, first_end, frequency, periods_between_ends, - auto_close_delta): + auto_close_delta, + exchange='TEST'): """ Create a DataFrame representing assets that all begin at the same start date, but have cascading end dates. @@ -146,6 +150,7 @@ def make_jagged_equity_info(num_assets, info : pd.DataFrame DataFrame representing newly-created assets. """ + num_assets = len(sids) frame = pd.DataFrame( { 'symbol': [chr(ord('A') + i) for i in range(num_assets)], @@ -155,9 +160,9 @@ def make_jagged_equity_info(num_assets, freq=(periods_between_ends * frequency), periods=num_assets, ), - 'exchange': 'TEST', + 'exchange': exchange, }, - index=range(num_assets), + index=sids, ) # Explicitly pass None to disable setting the auto_close_date column. @@ -167,6 +172,61 @@ def make_jagged_equity_info(num_assets, return frame +def make_multi_exchange_equity_info(factory, + exchange_sids, + exchange_kwargs=None, + **common_kwargs): + """ + Create an "equity_info" DataFrame for multiple exchanges by calling an + existing factory function for each exchange and concatting the results. + + Parameters + ---------- + factory : function + Function to use to create equity info for each exchange. + exchange_sids : dict[str -> list[sids]] + Map from exchange to list of sids to be created for that exchange. + exchange_kwargs : dict[str -> dict], optional + Map from exchange to additional kwargs to be passed for that exchange. + **common_kwargs + Additional keyword-arguments are forwarded to ``factory``. + + Returns + ------- + info : pd.DataFrame + DataFrame representing newly-created assets. + """ + if exchange_kwargs is None: + exchange_kwargs = {e: {} for e in exchange_sids} + else: + assert exchange_kwargs.keys() == exchange_sids.keys() + + # When using frequency-based factories, use each calendar's trading + # calendar for frequency by default. + provide_default_frequency = ( + 'frequency' not in common_kwargs + and factory in (make_rotating_equity_info, make_jagged_equity_info) + ) + if provide_default_frequency: + for e, kw in iteritems(exchange_kwargs): + kw.setdefault('frequency', get_calendar(e).day) + + frame_per_exchange = [ + factory( + sids=sids, + exchange=e, + **merge(common_kwargs, exchange_kwargs[e]) + ) + for e, sids in iteritems(exchange_sids) + ] + + result = pd.concat(frame_per_exchange) + if not result.index.is_unique: + raise AssertionError("Duplicate sids: {}".format(result.index)) + + return result + + def make_future_info(first_sid, root_symbols, years,