diff --git a/tests/test_calendar_dispatcher.py b/tests/test_calendar_dispatcher.py index a6896cf9..d37a7b08 100644 --- a/tests/test_calendar_dispatcher.py +++ b/tests/test_calendar_dispatcher.py @@ -10,6 +10,9 @@ from trading_calendars.calendar_utils import TradingCalendarDispatcher from trading_calendars.exchange_calendar_iepa import IEPAExchangeCalendar +import pandas as pd +import pytz + class CalendarAliasTestCase(TestCase): @@ -18,8 +21,8 @@ def setupClass(cls): # Make a calendar once so that we don't spend time in every test # instantiating calendars. cls.dispatcher_kwargs = dict( - calendars={'IEPA': IEPAExchangeCalendar()}, - calendar_factories={}, + calendars={}, + calendar_factories={'IEPA': IEPAExchangeCalendar}, aliases={ 'IEPA_ALIAS': 'IEPA', 'IEPA_ALIAS_ALIAS': 'IEPA_ALIAS', @@ -104,3 +107,10 @@ def test_get_calendar_names(self): sorted(self.dispatcher.get_calendar_names()), ['IEPA', 'IEPA_ALIAS', 'IEPA_ALIAS_ALIAS'] ) + + def test_kwarg_passing(self): + start_date = pd.Timestamp('2000-01-03', tz=pytz.UTC) + cal = self.dispatcher.get_calendar( + 'IEPA', start=start_date + ) + self.assertEqual(cal.first_session, start_date) diff --git a/trading_calendars/calendar_utils.py b/trading_calendars/calendar_utils.py index f426f4cf..6d6d664a 100644 --- a/trading_calendars/calendar_utils.py +++ b/trading_calendars/calendar_utils.py @@ -162,7 +162,7 @@ def __init__(self, calendars, calendar_factories, aliases): self._calendar_factories = dict(calendar_factories) self._aliases = dict(aliases) - def get_calendar(self, name): + def get_calendar(self, name, **kwargs): """ Retrieves an instance of an TradingCalendar whose name is given. @@ -170,7 +170,10 @@ def get_calendar(self, name): ---------- name : str The name of the TradingCalendar to be retrieved. - + kwargs: Dict[str, Any] + Optional keyword args passed to calendar `__init__. + Note: Arguments are only passed when the calendar is first + constructed. Subsequent calls will not affect the calendar. Returns ------- calendar : calendars.TradingCalendar @@ -191,7 +194,7 @@ def get_calendar(self, name): raise InvalidCalendarName(calendar_name=name) # Cache the calendar for future use. - calendar = self._calendars[canonical_name] = factory() + calendar = self._calendars[canonical_name] = factory(**kwargs) return calendar def get_calendar_names(self):