Skip to content

Commit

Permalink
update overridden fixes after sdk bump
Browse files Browse the repository at this point in the history
  • Loading branch information
pnadolny13 committed Jan 18, 2024
1 parent 4918b2c commit 8e28ea3
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 32 deletions.
20 changes: 6 additions & 14 deletions map_gpt_embeddings/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from singer_sdk import exceptions
from singer_sdk import typing as th
from singer_sdk._singerlib.messages import Message, SchemaMessage
from singer_sdk._singerlib.messages import Message, SchemaMessage, RecordMessage

from map_gpt_embeddings.sdk_fixes.mapper_base import BasicPassthroughMapper
from map_gpt_embeddings.sdk_fixes.messages import RecordMessage
from map_gpt_embeddings.stream import OpenAIStream
from map_gpt_embeddings.tap import TapOpenAI

Expand Down Expand Up @@ -76,27 +75,19 @@ def map_schema_message(self, message_dict: dict) -> t.Iterable[Message]:
),
).to_dict()

def _validate_config(
self,
*,
raise_errors: bool = True,
warnings_as_errors: bool = False,
) -> tuple[list[str], list[str]]:
def _validate_config(self, *, raise_errors: bool = True) -> list[str]:
"""Validate configuration input against the plugin configuration JSON schema.
Args:
raise_errors: Flag to throw an exception if any validation errors are found.
warnings_as_errors: Flag to throw an exception if any warnings were emitted.
Returns:
A tuple of configuration validation warnings and errors.
A list of validation errors.
Raises:
ConfigValidationError: If raise_errors is True and validation fails.
"""
warnings, errors = super()._validate_config(
raise_errors=raise_errors, warnings_as_errors=warnings_as_errors
)
errors = super()._validate_config(raise_errors=raise_errors)
if (
raise_errors
and self.config.get("openai_api_key", None) is None
Expand All @@ -107,7 +98,8 @@ def _validate_config(
f"`{self.name.upper().replace('-', '_')}_OPEN_API_KEY` env var, or "
" `OPENAI_API_KEY` env var."
)
return warnings, errors

return errors

def split_record(self, record: dict) -> t.Iterable[dict]:
"""Split a record dict to zero or more record dicts.
Expand Down
3 changes: 1 addition & 2 deletions map_gpt_embeddings/sdk_fixes/mapper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
Message,
SchemaMessage,
StateMessage,
RecordMessage,
)
from singer_sdk.mapper_base import InlineMapper

from map_gpt_embeddings.sdk_fixes.messages import RecordMessage


class BasicPassthroughMapper(InlineMapper):
"""A mapper to split documents into document segments."""
Expand Down
16 changes: 0 additions & 16 deletions map_gpt_embeddings/sdk_fixes/messages.py

This file was deleted.

0 comments on commit 8e28ea3

Please sign in to comment.