Skip to content

Commit

Permalink
fix: sub module conflict error
Browse files Browse the repository at this point in the history
  • Loading branch information
StellarisW committed Nov 29, 2024
1 parent 8e226b1 commit 08b5ac6
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 12 deletions.
1 change: 1 addition & 0 deletions tests/parser-cases/foo.bar.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include "foo/bar.thrift"
Empty file.
1 change: 1 addition & 0 deletions tests/parser-cases/include.thrift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include "included.thrift"
include "include/included_1.thrift"

const included.Timestamp datetime = 1422009523
1 change: 1 addition & 0 deletions tests/parser-cases/include/included_1.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include "included_2.thrift"
Empty file.
8 changes: 5 additions & 3 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def test_load_struct():
def test_load_union():
assert storm_tt.JavaObjectArg.__base__ == TPayload
assert storm.JavaObjectArg.thrift_spec == \
storm_tt.JavaObjectArg.thrift_spec
storm_tt.JavaObjectArg.thrift_spec


def test_load_exc():
assert ab_tt.PersonNotExistsError.__base__ == TException
assert ab.PersonNotExistsError.thrift_spec == \
ab_tt.PersonNotExistsError.thrift_spec
ab_tt.PersonNotExistsError.thrift_spec


def test_load_service():
Expand All @@ -70,4 +70,6 @@ def test_load_include():
g = load("parent.thrift")

ts = g.Greet.thrift_spec
assert ts[1][2] == b.Hello and ts[2][0] == TType.I64 and ts[3][2] == b.Code
assert (ts[1][2].thrift_spec == b.Hello.thrift_spec and
ts[2][0] == TType.I64 and
ts[3][2]._NAMES_TO_VALUES == b.Code._NAMES_TO_VALUES)
22 changes: 20 additions & 2 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-

import sys
import threading

import pytest
Expand Down Expand Up @@ -36,8 +36,26 @@ def test_constants():

def test_include():
thrift = load('parser-cases/include.thrift', include_dirs=[
'./parser-cases'])
'./parser-cases'], module_name='include_thrift')
assert thrift.datetime == 1422009523
assert sys.modules['include_thrift'] is not None
assert sys.modules['included_thrift'] is not None
assert sys.modules['include.included_1_thrift'] is not None
assert sys.modules['include.included_2_thrift'] is not None


def test_include_with_module_name_prefix():
load('parser-cases/include.thrift', module_name='parser_cases.include_thrift')
assert sys.modules['parser_cases.include_thrift'] is not None
assert sys.modules['parser_cases.included_thrift'] is not None
assert sys.modules['parser_cases.include.included_1_thrift'] is not None
assert sys.modules['parser_cases.include.included_2_thrift'] is not None


def test_include_conflict():
with pytest.raises(ThriftParserError) as excinfo:
load('parser-cases/foo.bar.thrift', module_name='foo.bar_thrift')
assert 'Module name conflict between' in str(excinfo.value)


def test_cpp_include():
Expand Down
25 changes: 19 additions & 6 deletions thriftpy2/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,25 @@ def load(path,
# add sub modules to sys.modules recursively
if real_module:
sys.modules[module_name] = thrift
sub_modules = thrift.__thrift_meta__["includes"][:]
while sub_modules:
module = sub_modules.pop()
if module not in sys.modules:
sys.modules[module.__name__] = module
sub_modules.extend(module.__thrift_meta__["includes"])
include_thrifts = list(zip(thrift.__thrift_meta__["includes"][:],
thrift.__thrift_meta__["sub_modules"][:]))
while include_thrifts:
include_thrift = include_thrifts.pop()
registered_thrift = sys.modules.get(include_thrift[1].__name__)
if registered_thrift is None:
sys.modules[include_thrift[1].__name__] = include_thrift[0]
if hasattr(include_thrift[0], "__thrift_meta__"):
include_thrifts.extend(
list(
zip(
include_thrift[0].__thrift_meta__["includes"],
include_thrift[0].__thrift_meta__["sub_modules"])))
else:
if registered_thrift.__thrift_file__ != include_thrift[0].__thrift_file__:
raise ThriftParserError(
'Module name conflict between "%s" and "%s"' %
(registered_thrift.__thrift_file__, include_thrift[0].__thrift_file__)
)
return thrift


Expand Down
16 changes: 15 additions & 1 deletion thriftpy2/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,23 @@ def p_include(p):
for include_dir in replace_include_dirs:
path = os.path.join(include_dir, p[2])
if os.path.exists(path):
child = parse(path)
thrift_file_name_module = os.path.basename(thrift.__thrift_file__)
if thrift_file_name_module.endswith(".thrift"):
thrift_file_name_module = thrift_file_name_module[:-7] + "_thrift"
module_prefix = str(thrift.__name__).rstrip(thrift_file_name_module)

child_rel_path = os.path.relpath(str(path), os.path.dirname(thrift.__thrift_file__))
child_module_name = str(child_rel_path).replace(os.sep, ".").replace(".thrift", "_thrift")
child_module_name = module_prefix + child_module_name

child = parse(path, module_name=child_module_name)
child_include_module_name = os.path.basename(path)
if child_include_module_name.endswith(".thrift"):
child_include_module_name = child_include_module_name[:-7]
child.__name__ = child_include_module_name
setattr(thrift, child.__name__, child)
_add_thrift_meta('includes', child)
_add_thrift_meta('sub_modules', types.ModuleType(child_module_name))
return
raise ThriftParserError(('Couldn\'t include thrift %s in any '
'directories provided') % p[2])
Expand Down

0 comments on commit 08b5ac6

Please sign in to comment.