diff --git a/tests/parser-cases/foo.bar.thrift b/tests/parser-cases/foo.bar.thrift new file mode 100644 index 0000000..d9b3174 --- /dev/null +++ b/tests/parser-cases/foo.bar.thrift @@ -0,0 +1 @@ +include "foo/bar.thrift" \ No newline at end of file diff --git a/tests/parser-cases/foo/bar.thrift b/tests/parser-cases/foo/bar.thrift new file mode 100644 index 0000000..e69de29 diff --git a/tests/parser-cases/include.thrift b/tests/parser-cases/include.thrift index 14678cf..37dbc8b 100644 --- a/tests/parser-cases/include.thrift +++ b/tests/parser-cases/include.thrift @@ -1,3 +1,4 @@ include "included.thrift" +include "include/included_1.thrift" const included.Timestamp datetime = 1422009523 diff --git a/tests/parser-cases/include/included_1.thrift b/tests/parser-cases/include/included_1.thrift new file mode 100644 index 0000000..a803db8 --- /dev/null +++ b/tests/parser-cases/include/included_1.thrift @@ -0,0 +1 @@ +include "included_2.thrift" \ No newline at end of file diff --git a/tests/parser-cases/include/included_2.thrift b/tests/parser-cases/include/included_2.thrift new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_loader.py b/tests/test_loader.py index 5892235..6c9c36a 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -51,13 +51,13 @@ def test_load_struct(): def test_load_union(): assert storm_tt.JavaObjectArg.__base__ == TPayload assert storm.JavaObjectArg.thrift_spec == \ - storm_tt.JavaObjectArg.thrift_spec + storm_tt.JavaObjectArg.thrift_spec def test_load_exc(): assert ab_tt.PersonNotExistsError.__base__ == TException assert ab.PersonNotExistsError.thrift_spec == \ - ab_tt.PersonNotExistsError.thrift_spec + ab_tt.PersonNotExistsError.thrift_spec def test_load_service(): @@ -70,4 +70,6 @@ def test_load_include(): g = load("parent.thrift") ts = g.Greet.thrift_spec - assert ts[1][2] == b.Hello and ts[2][0] == TType.I64 and ts[3][2] == b.Code + assert (ts[1][2].thrift_spec == b.Hello.thrift_spec and + ts[2][0] == TType.I64 and + ts[3][2]._NAMES_TO_VALUES == b.Code._NAMES_TO_VALUES) diff --git a/tests/test_parser.py b/tests/test_parser.py index cb51a71..ea810dd 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import sys import threading import pytest @@ -36,8 +36,26 @@ def test_constants(): def test_include(): thrift = load('parser-cases/include.thrift', include_dirs=[ - './parser-cases']) + './parser-cases'], module_name='include_thrift') assert thrift.datetime == 1422009523 + assert sys.modules['include_thrift'] is not None + assert sys.modules['included_thrift'] is not None + assert sys.modules['include.included_1_thrift'] is not None + assert sys.modules['include.included_2_thrift'] is not None + + +def test_include_with_module_name_prefix(): + load('parser-cases/include.thrift', module_name='parser_cases.include_thrift') + assert sys.modules['parser_cases.include_thrift'] is not None + assert sys.modules['parser_cases.included_thrift'] is not None + assert sys.modules['parser_cases.include.included_1_thrift'] is not None + assert sys.modules['parser_cases.include.included_2_thrift'] is not None + + +def test_include_conflict(): + with pytest.raises(ThriftParserError) as excinfo: + load('parser-cases/foo.bar.thrift', module_name='foo.bar_thrift') + assert 'Module name conflict between' in str(excinfo.value) def test_cpp_include(): diff --git a/thriftpy2/parser/__init__.py b/thriftpy2/parser/__init__.py index 8930ce8..2f4ff2c 100644 --- a/thriftpy2/parser/__init__.py +++ b/thriftpy2/parser/__init__.py @@ -41,12 +41,25 @@ def load(path, # add sub modules to sys.modules recursively if real_module: sys.modules[module_name] = thrift - sub_modules = thrift.__thrift_meta__["includes"][:] - while sub_modules: - module = sub_modules.pop() - if module not in sys.modules: - sys.modules[module.__name__] = module - sub_modules.extend(module.__thrift_meta__["includes"]) + include_thrifts = list(zip(thrift.__thrift_meta__["includes"][:], + thrift.__thrift_meta__["sub_modules"][:])) + while include_thrifts: + include_thrift = include_thrifts.pop() + registered_thrift = sys.modules.get(include_thrift[1].__name__) + if registered_thrift is None: + sys.modules[include_thrift[1].__name__] = include_thrift[0] + if hasattr(include_thrift[0], "__thrift_meta__"): + include_thrifts.extend( + list( + zip( + include_thrift[0].__thrift_meta__["includes"], + include_thrift[0].__thrift_meta__["sub_modules"]))) + else: + if registered_thrift.__thrift_file__ != include_thrift[0].__thrift_file__: + raise ThriftParserError( + 'Module name conflict between "%s" and "%s"' % + (registered_thrift.__thrift_file__, include_thrift[0].__thrift_file__) + ) return thrift diff --git a/thriftpy2/parser/parser.py b/thriftpy2/parser/parser.py index 9ad2ebd..d1bc627 100644 --- a/thriftpy2/parser/parser.py +++ b/thriftpy2/parser/parser.py @@ -62,9 +62,23 @@ def p_include(p): for include_dir in replace_include_dirs: path = os.path.join(include_dir, p[2]) if os.path.exists(path): - child = parse(path) + thrift_file_name_module = os.path.basename(thrift.__thrift_file__) + if thrift_file_name_module.endswith(".thrift"): + thrift_file_name_module = thrift_file_name_module[:-7] + "_thrift" + module_prefix = str(thrift.__name__).rstrip(thrift_file_name_module) + + child_rel_path = os.path.relpath(str(path), os.path.dirname(thrift.__thrift_file__)) + child_module_name = str(child_rel_path).replace(os.sep, ".").replace(".thrift", "_thrift") + child_module_name = module_prefix + child_module_name + + child = parse(path, module_name=child_module_name) + child_include_module_name = os.path.basename(path) + if child_include_module_name.endswith(".thrift"): + child_include_module_name = child_include_module_name[:-7] + child.__name__ = child_include_module_name setattr(thrift, child.__name__, child) _add_thrift_meta('includes', child) + _add_thrift_meta('sub_modules', types.ModuleType(child_module_name)) return raise ThriftParserError(('Couldn\'t include thrift %s in any ' 'directories provided') % p[2])