Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: sub module conflict error #295

Merged
merged 5 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/parser-cases/foo.bar.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include "foo/bar.thrift"
Empty file.
1 change: 1 addition & 0 deletions tests/parser-cases/include.thrift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include "included.thrift"
include "include/included_1.thrift"

const included.Timestamp datetime = 1422009523
1 change: 1 addition & 0 deletions tests/parser-cases/include/included_1.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include "included_2.thrift"
Empty file.
8 changes: 5 additions & 3 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def test_load_struct():
def test_load_union():
assert storm_tt.JavaObjectArg.__base__ == TPayload
assert storm.JavaObjectArg.thrift_spec == \
storm_tt.JavaObjectArg.thrift_spec
storm_tt.JavaObjectArg.thrift_spec


def test_load_exc():
assert ab_tt.PersonNotExistsError.__base__ == TException
assert ab.PersonNotExistsError.thrift_spec == \
ab_tt.PersonNotExistsError.thrift_spec
ab_tt.PersonNotExistsError.thrift_spec


def test_load_service():
Expand All @@ -70,4 +70,6 @@ def test_load_include():
g = load("parent.thrift")

ts = g.Greet.thrift_spec
assert ts[1][2] == b.Hello and ts[2][0] == TType.I64 and ts[3][2] == b.Code
assert (ts[1][2].thrift_spec == b.Hello.thrift_spec and
ts[2][0] == TType.I64 and
ts[3][2]._NAMES_TO_VALUES == b.Code._NAMES_TO_VALUES)
25 changes: 23 additions & 2 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-

import sys
import threading

import pytest
Expand Down Expand Up @@ -36,8 +36,26 @@ def test_constants():

def test_include():
thrift = load('parser-cases/include.thrift', include_dirs=[
'./parser-cases'])
'./parser-cases'], module_name='include_thrift')
assert thrift.datetime == 1422009523
assert sys.modules['include_thrift'] is not None
assert sys.modules['included_thrift'] is not None
assert sys.modules['include.included_1_thrift'] is not None
assert sys.modules['include.included_2_thrift'] is not None


def test_include_with_module_name_prefix():
load('parser-cases/include.thrift', module_name='parser_cases.include_thrift')
assert sys.modules['parser_cases.include_thrift'] is not None
assert sys.modules['parser_cases.included_thrift'] is not None
assert sys.modules['parser_cases.include.included_1_thrift'] is not None
assert sys.modules['parser_cases.include.included_2_thrift'] is not None


def test_include_conflict():
with pytest.raises(ThriftParserError) as excinfo:
load('parser-cases/foo.bar.thrift', module_name='foo.bar_thrift')
assert 'Module name conflict between' in str(excinfo.value)


def test_cpp_include():
Expand Down Expand Up @@ -295,6 +313,9 @@ def test_thrift_meta():


def test_load_fp():
from thriftpy2.parser import threadlocal
threadlocal.__dict__.clear()

thrift = None
with open('parser-cases/shared.thrift') as thrift_fp:
thrift = load_fp(thrift_fp, 'shared_thrift')
Expand Down
23 changes: 16 additions & 7 deletions thriftpy2/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import types

from .parser import parse, parse_fp, threadlocal, _cast
from .exc import ThriftParserError
from .exc import ThriftParserError, ThriftModuleNameConflict
from ..thrift import TPayloadMeta


Expand All @@ -41,12 +41,21 @@ def load(path,
# add sub modules to sys.modules recursively
if real_module:
sys.modules[module_name] = thrift
sub_modules = thrift.__thrift_meta__["includes"][:]
while sub_modules:
module = sub_modules.pop()
if module not in sys.modules:
sys.modules[module.__name__] = module
sub_modules.extend(module.__thrift_meta__["includes"])
include_thrifts = thrift.__thrift_meta__["includes"][:]
while include_thrifts:
include_thrift = include_thrifts.pop()
registered_thrift = sys.modules.get(include_thrift.__thrift_module_name__)
if registered_thrift is None:
sys.modules[include_thrift.__thrift_module_name__] = include_thrift
if hasattr(include_thrift, "__thrift_meta__"):
include_thrifts.extend(
include_thrift.__thrift_meta__["includes"][:])
else:
if registered_thrift.__thrift_file__ != include_thrift.__thrift_file__:
raise ThriftModuleNameConflict(
'Module name conflict between "%s" and "%s"' %
(registered_thrift.__thrift_file__, include_thrift.__thrift_file__)
)
return thrift


Expand Down
4 changes: 4 additions & 0 deletions thriftpy2/parser/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ class ThriftParserError(Exception):
pass


class ThriftModuleNameConflict(ThriftParserError):
pass


class ThriftLexerError(ThriftParserError):
pass

Expand Down
16 changes: 15 additions & 1 deletion thriftpy2/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,21 @@ def p_include(p):
for include_dir in replace_include_dirs:
path = os.path.join(include_dir, p[2])
if os.path.exists(path):
child = parse(path)
thrift_file_name_module = os.path.basename(thrift.__thrift_file__)
if thrift_file_name_module.endswith(".thrift"):
thrift_file_name_module = thrift_file_name_module[:-7] + "_thrift"
module_prefix = str(thrift.__name__).rstrip(thrift_file_name_module)

child_rel_path = os.path.relpath(str(path), os.path.dirname(thrift.__thrift_file__))
child_module_name = str(child_rel_path).replace(os.sep, ".").replace(".thrift", "_thrift")
child_module_name = module_prefix + child_module_name

child = parse(path, module_name=child_module_name)
child_include_module_name = os.path.basename(path)
if child_include_module_name.endswith(".thrift"):
child_include_module_name = child_include_module_name[:-7]
setattr(child, '__name__', child_include_module_name)
setattr(child, '__thrift_module_name__', child_module_name)
setattr(thrift, child.__name__, child)
_add_thrift_meta('includes', child)
return
Expand Down