diff --git a/python/instrumentation/openinference-instrumentation-dspy/pyproject.toml b/python/instrumentation/openinference-instrumentation-dspy/pyproject.toml index 64b3d6143..44da863e4 100644 --- a/python/instrumentation/openinference-instrumentation-dspy/pyproject.toml +++ b/python/instrumentation/openinference-instrumentation-dspy/pyproject.toml @@ -55,3 +55,13 @@ include = [ [tool.hatch.build.targets.wheel] packages = ["src/openinference"] + +[tool.mypy] +plugins = [] +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +strict = true +exclude = [ + "dist/", +] diff --git a/python/instrumentation/openinference-instrumentation-dspy/src/openinference/instrumentation/dspy/__init__.py b/python/instrumentation/openinference-instrumentation-dspy/src/openinference/instrumentation/dspy/__init__.py index c0cc5f06f..b519cad90 100644 --- a/python/instrumentation/openinference-instrumentation-dspy/src/openinference/instrumentation/dspy/__init__.py +++ b/python/instrumentation/openinference-instrumentation-dspy/src/openinference/instrumentation/dspy/__init__.py @@ -1,6 +1,6 @@ import json from abc import ABC -from typing import Any, Callable, Collection, Dict, Mapping, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Mapping, Tuple from openinference.instrumentation.dspy.package import _instruments from openinference.instrumentation.dspy.version import __version__ @@ -10,16 +10,19 @@ SpanAttributes, ) from opentelemetry import trace as trace_api -from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore from wrapt import wrap_function_wrapper +if TYPE_CHECKING: + import dspy + _DSPY_MODULE = "dspy" # DSPy used to be called DSP - some of the modules still fall under the old namespace _DSP_MODULE = "dsp" -class DSPyInstrumentor(BaseInstrumentor): +class DSPyInstrumentor(BaseInstrumentor): # type: ignore """ OpenInference Instrumentor for DSPy """ @@ -27,7 +30,7 @@ class DSPyInstrumentor(BaseInstrumentor): def instrumentation_dependencies(self) -> Collection[str]: return _instruments - def _instrument(self, **kwargs): + def _instrument(self, **kwargs: Any) -> None: if not (tracer_provider := kwargs.get("tracer_provider")): tracer_provider = trace_api.get_tracer_provider() tracer = trace_api.get_tracer(__name__, __version__, tracer_provider) @@ -50,7 +53,7 @@ def _instrument(self, **kwargs): wrapper=_PredictForwardWrapper(tracer), ) - def _uninstrument(self, **kwargs): + def _uninstrument(self, **kwargs: Any) -> None: from dsp.modules.lm import LM language_model_classes = LM.__subclasses__() @@ -86,7 +89,7 @@ def __call__( instance: Any, args: Tuple[type, Any], kwargs: Mapping[str, Any], - ) -> None: + ) -> Any: print("LM.basic_request") prompt = args[0] kwargs = {**instance.kwargs, **kwargs} @@ -130,7 +133,7 @@ def __call__( instance: Any, args: Tuple[type, Any], kwargs: Mapping[str, Any], - ) -> None: + ) -> Any: signature = kwargs.get("signature", instance.signature) span_name = signature.__name__ + ".forward" with self._tracer.start_as_current_span( @@ -159,8 +162,8 @@ def __call__( return prediction def _prediction_to_output_dict( - self, prediction, signature - ) -> Dict[str, Union[str, int, float, bool]]: + self, prediction: dspy.Prediction, signature: dspy.Signature + ) -> Dict[str, Any]: """ Parse the prediction fields to get the input and output fields """ diff --git a/python/instrumentation/openinference-instrumentation-dspy/tests/openinference/instrumentation/dspy/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-dspy/tests/openinference/instrumentation/dspy/test_instrumentor.py index 233b2d5c1..478f3c0bd 100644 --- a/python/instrumentation/openinference-instrumentation-dspy/tests/openinference/instrumentation/dspy/test_instrumentor.py +++ b/python/instrumentation/openinference-instrumentation-dspy/tests/openinference/instrumentation/dspy/test_instrumentor.py @@ -42,7 +42,7 @@ def instrument( def test_openai_lm( in_memory_span_exporter: InMemorySpanExporter, ) -> None: - class BasicQA(dspy.Signature): + class BasicQA(dspy.Signature): # type: ignore """Answer questions with short factoid answers.""" question = dspy.InputField() @@ -84,4 +84,4 @@ class BasicQA(dspy.Signature): chain_span = spans[1] assert chain_span.name == "BasicQA.forward" assert lm_span.name == "GPT3.request" - assert question in lm_span.attributes[SpanAttributes.INPUT_VALUE] + assert question in lm_span.attributes[SpanAttributes.INPUT_VALUE] # type: ignore diff --git a/python/mypy.ini b/python/mypy.ini index c8ed2018d..2d2b6180d 100644 --- a/python/mypy.ini +++ b/python/mypy.ini @@ -9,3 +9,7 @@ exclude = (?x)( [mypy-wrapt] ignore_missing_imports = True +[mypy-dspy.*] +ignore_missing_imports = True +[mypy-dsp.*] +ignore_missing_imports = True