Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeldking committed Jan 23, 2024
1 parent 61b3aca commit 007a146
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,13 @@ include = [

[tool.hatch.build.targets.wheel]
packages = ["src/openinference"]

[tool.mypy]
plugins = []
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
strict = true
exclude = [
"dist/",
]
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from abc import ABC
from typing import Any, Callable, Collection, Dict, Mapping, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Mapping, Tuple

from openinference.instrumentation.dspy.package import _instruments
from openinference.instrumentation.dspy.version import __version__
Expand All @@ -10,24 +10,27 @@
SpanAttributes,
)
from opentelemetry import trace as trace_api
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore
from wrapt import wrap_function_wrapper

if TYPE_CHECKING:
import dspy

_DSPY_MODULE = "dspy"

# DSPy used to be called DSP - some of the modules still fall under the old namespace
_DSP_MODULE = "dsp"


class DSPyInstrumentor(BaseInstrumentor):
class DSPyInstrumentor(BaseInstrumentor): # type: ignore
"""
OpenInference Instrumentor for DSPy
"""

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Any) -> None:
if not (tracer_provider := kwargs.get("tracer_provider")):
tracer_provider = trace_api.get_tracer_provider()
tracer = trace_api.get_tracer(__name__, __version__, tracer_provider)
Expand All @@ -50,7 +53,7 @@ def _instrument(self, **kwargs):
wrapper=_PredictForwardWrapper(tracer),
)

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: Any) -> None:
from dsp.modules.lm import LM

language_model_classes = LM.__subclasses__()
Expand Down Expand Up @@ -86,7 +89,7 @@ def __call__(
instance: Any,
args: Tuple[type, Any],
kwargs: Mapping[str, Any],
) -> None:
) -> Any:
print("LM.basic_request")
prompt = args[0]
kwargs = {**instance.kwargs, **kwargs}
Expand Down Expand Up @@ -130,7 +133,7 @@ def __call__(
instance: Any,
args: Tuple[type, Any],
kwargs: Mapping[str, Any],
) -> None:
) -> Any:
signature = kwargs.get("signature", instance.signature)
span_name = signature.__name__ + ".forward"
with self._tracer.start_as_current_span(
Expand Down Expand Up @@ -159,8 +162,8 @@ def __call__(
return prediction

def _prediction_to_output_dict(
self, prediction, signature
) -> Dict[str, Union[str, int, float, bool]]:
self, prediction: dspy.Prediction, signature: dspy.Signature
) -> Dict[str, Any]:
"""
Parse the prediction fields to get the input and output fields
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def instrument(
def test_openai_lm(
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
class BasicQA(dspy.Signature):
class BasicQA(dspy.Signature): # type: ignore
"""Answer questions with short factoid answers."""

question = dspy.InputField()
Expand Down Expand Up @@ -84,4 +84,4 @@ class BasicQA(dspy.Signature):
chain_span = spans[1]
assert chain_span.name == "BasicQA.forward"
assert lm_span.name == "GPT3.request"
assert question in lm_span.attributes[SpanAttributes.INPUT_VALUE]
assert question in lm_span.attributes[SpanAttributes.INPUT_VALUE] # type: ignore
4 changes: 4 additions & 0 deletions python/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ exclude = (?x)(

[mypy-wrapt]
ignore_missing_imports = True
[mypy-dspy.*]
ignore_missing_imports = True
[mypy-dsp.*]
ignore_missing_imports = True

0 comments on commit 007a146

Please sign in to comment.