Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixes typing issues discovered from github api generation #402

Merged
merged 3 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions packages/abstractions/kiota_abstractions/request_adapter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from datetime import date, datetime, time, timedelta
from io import BytesIO
from typing import Dict, Generic, List, Optional, TypeVar, Union
from uuid import UUID

from .request_information import RequestInformation
from .serialization import Parsable, ParsableFactory, SerializationWriterFactory
Expand All @@ -11,6 +12,9 @@
ResponseType = TypeVar("ResponseType")
ModelType = TypeVar("ModelType", bound=Parsable)
RequestType = TypeVar("RequestType")
PrimitiveType = TypeVar(
"PrimitiveType", bool, str, int, float, UUID, datetime, timedelta, date, time, bytes
)


class RequestAdapter(ABC, Generic[RequestType]):
Expand Down Expand Up @@ -75,21 +79,21 @@ async def send_collection_async(
async def send_collection_of_primitive_async(
self,
request_info: RequestInformation,
response_type: ResponseType,
response_type: type[PrimitiveType],
error_map: Optional[Dict[str, type[ParsableFactory]]],
) -> Optional[List[ResponseType]]:
) -> Optional[List[PrimitiveType]]:
"""Excutes the HTTP request specified by the given RequestInformation and returns the
deserialized response model collection.

Args:
request_info (RequestInformation): the request info to execute.
response_type (ResponseType): the class of the response model to deserialize the
response_type (PrimitiveType): the class of the response model to deserialize the
response into.
error_map (Optional[Dict[str, type[ParsableFactory]]]): the error dict to use in
case of a failed request.

Returns:
Optional[List[ModelType]]: The deserialized response model collection.
Optional[List[PrimitiveType]]: The deserialized primitive collection.
"""
pass

Expand Down
17 changes: 6 additions & 11 deletions packages/abstractions/kiota_abstractions/request_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from .request_adapter import RequestAdapter

Url = str
T = TypeVar("T", bound=Parsable)
T = TypeVar("T", bool, str, int, float, UUID, datetime, timedelta, date, time, bytes)
U = TypeVar("U", bound=Parsable)
QueryParameters = TypeVar('QueryParameters')
OBSERVABILITY_TRACER_NAME = "microsoft-python-kiota-abstractions"
tracer = trace.get_tracer(OBSERVABILITY_TRACER_NAME, VERSION)
Expand Down Expand Up @@ -155,20 +156,20 @@ def set_content_from_parsable(
self,
request_adapter: RequestAdapter,
content_type: str,
values: Union[T, List[T]],
values: Union[U, List[U]],
) -> None:
"""Sets the request body from a model with the specified content type.

Args:
request_adapter (Optional[RequestAdapter]): The adapter service to get the serialization
writer from.
content_type (Optional[str]): the content type.
values (Union[T, List[T]]): the models.
values (Union[U, List[U]]): the models.
"""
with tracer.start_as_current_span(
self._create_parent_span_name("set_content_from_parsable")
) as span:
writer = self._get_serialization_writer(request_adapter, content_type, values, span)
writer = self._get_serialization_writer(request_adapter, content_type, span)
if isinstance(values, MultipartBody):
content_type += f"; boundary={values.boundary}"
values.request_adapter = request_adapter
Expand Down Expand Up @@ -198,7 +199,7 @@ def set_content_from_scalar(
with tracer.start_as_current_span(
self._create_parent_span_name("set_content_from_scalar")
) as span:
writer = self._get_serialization_writer(request_adapter, content_type, values, span)
writer = self._get_serialization_writer(request_adapter, content_type, span)

if isinstance(values, list):
writer.writer = writer.write_collection_of_primitive_values(None, values)
Expand Down Expand Up @@ -255,15 +256,13 @@ def _get_serialization_writer(
self,
request_adapter: Optional["RequestAdapter"],
content_type: Optional[str],
values: Union[T, List[T]],
parent_span: trace.Span,
):
"""_summary_

Args:
request_adapter (RequestAdapter): _description_
content_type (str): _description_
values (Union[T, List[T]]): _description_
"""
_span = self._start_local_tracing_span("_get_serialization_writer", parent_span)
try:
Expand All @@ -275,10 +274,6 @@ def _get_serialization_writer(
exc = ValueError("Content Type cannot be null")
_span.record_exception(exc)
raise exc
if not values:
exc = ValueError("Values cannot be null")
_span.record_exception(exc)
raise exc
return request_adapter.get_serialization_writer_factory(
).get_serialization_writer(content_type)
finally:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_time_value(self) -> Optional[time]:
pass

@abstractmethod
def get_collection_of_primitive_values(self, primitive_type) -> Optional[List[T]]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> Optional[List[T]]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -128,7 +128,7 @@ def get_collection_of_primitive_values(self, primitive_type) -> Optional[List[T]
pass

@abstractmethod
def get_collection_of_object_values(self, factory: ParsableFactory) -> Optional[List[U]]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> Optional[List[U]]:
"""Gets the collection of model object values of the node
Args:
factory (ParsableFactory): The factory to use to create the model object.
Expand Down
10 changes: 5 additions & 5 deletions packages/http/httpx/kiota_http/httpx_request_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from kiota_abstractions.api_error import APIError
from kiota_abstractions.authentication import AuthenticationProvider
from kiota_abstractions.request_adapter import RequestAdapter, ResponseType
from kiota_abstractions.request_adapter import RequestAdapter, ResponseType, PrimitiveType
from kiota_abstractions.request_information import RequestInformation
from kiota_abstractions.serialization import (
Parsable,
Expand Down Expand Up @@ -250,20 +250,20 @@ async def send_collection_async(
async def send_collection_of_primitive_async(
self,
request_info: RequestInformation,
response_type: ResponseType,
response_type: type[PrimitiveType],
error_map: Optional[Dict[str, type[ParsableFactory]]],
) -> Optional[List[ResponseType]]:
) -> Optional[List[PrimitiveType]]:
"""Excutes the HTTP request specified by the given RequestInformation and returns the
deserialized response model collection.
Args:
request_info (RequestInformation): the request info to execute.
response_type (ResponseType): the class of the response model
response_type (PrimitiveType): the class of the response model
to deserialize the response into.
error_map (Dict[str, type[ParsableFactory]]): the error dict to use in
case of a failed request.

Returns:
Optional[List[ResponseType]]: he deserialized response model collection.
Optional[List[PrimitiveType]]: The deserialized primitive type collection.
"""
parent_span = self.start_tracing_span(request_info, "send_collection_of_primitive_async")
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_child_node(self, field_name: str) -> Optional[ParseNode]:
return FormParseNode(self._fields[field_name])
return None

def get_collection_of_primitive_values(self, primitive_type: type) -> Optional[List[T]]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> Optional[List[T]]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -189,7 +189,7 @@ def get_collection_of_primitive_values(self, primitive_type: type) -> Optional[L
return result
raise Exception(f"Encountered an unknown type during deserialization {primitive_type}")

def get_collection_of_object_values(self, factory: ParsableFactory) -> Optional[List[U]]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> Optional[List[U]]:
raise Exception("Collection of object values is not supported with uri form encoding.")

def get_collection_of_enum_values(self, enum_class: K) -> Optional[List[K]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def get_time_value(self) -> Optional[time]:
return datetime_obj
return None

def get_collection_of_primitive_values(self, primitive_type: Any) -> Optional[List[T]]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> Optional[List[T]]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -161,7 +161,7 @@ def func(item):
return list(map(func, json.loads(self._json_node)))
return list(map(func, list(self._json_node)))

def get_collection_of_object_values(self, factory: ParsableFactory) -> Optional[List[U]]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> Optional[List[U]]:
"""Gets the collection of type U values from the json node
Returns:
List[U]: The collection of model object values of the node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_time_value(self) -> Optional[time]:
return datetime_obj.time()
return None

def get_collection_of_primitive_values(self, primitive_type) -> List[T]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> List[T]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -142,7 +142,7 @@ def get_collection_of_primitive_values(self, primitive_type) -> List[T]:
"""
raise Exception(self.NO_STRUCTURED_DATA_MESSAGE)

def get_collection_of_object_values(self, factory: ParsableFactory) -> List[U]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> List[U]:
"""Gets the collection of type U values from the text node
Returns:
List[U]: The collection of model object values of the node
Expand Down
Loading