Skip to content

Commit

Permalink
Merge pull request #402 from microsoft/andrueastman/fixTypingErrors
Browse files Browse the repository at this point in the history
fix: fixes typing issues discovered from github api generation
  • Loading branch information
andrueastman authored Nov 9, 2024
2 parents 72eb943 + 52598b3 commit 92cf4c5
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 29 deletions.
14 changes: 9 additions & 5 deletions packages/abstractions/kiota_abstractions/request_adapter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from datetime import date, datetime, time, timedelta
from io import BytesIO
from typing import Dict, Generic, List, Optional, TypeVar, Union
from uuid import UUID

from .request_information import RequestInformation
from .serialization import Parsable, ParsableFactory, SerializationWriterFactory
Expand All @@ -11,6 +12,9 @@
ResponseType = TypeVar("ResponseType")
ModelType = TypeVar("ModelType", bound=Parsable)
RequestType = TypeVar("RequestType")
PrimitiveType = TypeVar(
"PrimitiveType", bool, str, int, float, UUID, datetime, timedelta, date, time, bytes
)


class RequestAdapter(ABC, Generic[RequestType]):
Expand Down Expand Up @@ -75,21 +79,21 @@ async def send_collection_async(
async def send_collection_of_primitive_async(
self,
request_info: RequestInformation,
response_type: ResponseType,
response_type: type[PrimitiveType],
error_map: Optional[Dict[str, type[ParsableFactory]]],
) -> Optional[List[ResponseType]]:
) -> Optional[List[PrimitiveType]]:
"""Excutes the HTTP request specified by the given RequestInformation and returns the
deserialized response model collection.
Args:
request_info (RequestInformation): the request info to execute.
response_type (ResponseType): the class of the response model to deserialize the
response_type (PrimitiveType): the class of the response model to deserialize the
response into.
error_map (Optional[Dict[str, type[ParsableFactory]]]): the error dict to use in
case of a failed request.
Returns:
Optional[List[ModelType]]: The deserialized response model collection.
Optional[List[PrimitiveType]]: The deserialized primitive collection.
"""
pass

Expand Down
17 changes: 6 additions & 11 deletions packages/abstractions/kiota_abstractions/request_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from .request_adapter import RequestAdapter

Url = str
T = TypeVar("T", bound=Parsable)
T = TypeVar("T", bool, str, int, float, UUID, datetime, timedelta, date, time, bytes)
U = TypeVar("U", bound=Parsable)
QueryParameters = TypeVar('QueryParameters')
OBSERVABILITY_TRACER_NAME = "microsoft-python-kiota-abstractions"
tracer = trace.get_tracer(OBSERVABILITY_TRACER_NAME, VERSION)
Expand Down Expand Up @@ -155,20 +156,20 @@ def set_content_from_parsable(
self,
request_adapter: RequestAdapter,
content_type: str,
values: Union[T, List[T]],
values: Union[U, List[U]],
) -> None:
"""Sets the request body from a model with the specified content type.
Args:
request_adapter (Optional[RequestAdapter]): The adapter service to get the serialization
writer from.
content_type (Optional[str]): the content type.
values (Union[T, List[T]]): the models.
values (Union[U, List[U]]): the models.
"""
with tracer.start_as_current_span(
self._create_parent_span_name("set_content_from_parsable")
) as span:
writer = self._get_serialization_writer(request_adapter, content_type, values, span)
writer = self._get_serialization_writer(request_adapter, content_type, span)
if isinstance(values, MultipartBody):
content_type += f"; boundary={values.boundary}"
values.request_adapter = request_adapter
Expand Down Expand Up @@ -198,7 +199,7 @@ def set_content_from_scalar(
with tracer.start_as_current_span(
self._create_parent_span_name("set_content_from_scalar")
) as span:
writer = self._get_serialization_writer(request_adapter, content_type, values, span)
writer = self._get_serialization_writer(request_adapter, content_type, span)

if isinstance(values, list):
writer.writer = writer.write_collection_of_primitive_values(None, values)
Expand Down Expand Up @@ -255,15 +256,13 @@ def _get_serialization_writer(
self,
request_adapter: Optional["RequestAdapter"],
content_type: Optional[str],
values: Union[T, List[T]],
parent_span: trace.Span,
):
"""_summary_
Args:
request_adapter (RequestAdapter): _description_
content_type (str): _description_
values (Union[T, List[T]]): _description_
"""
_span = self._start_local_tracing_span("_get_serialization_writer", parent_span)
try:
Expand All @@ -275,10 +274,6 @@ def _get_serialization_writer(
exc = ValueError("Content Type cannot be null")
_span.record_exception(exc)
raise exc
if not values:
exc = ValueError("Values cannot be null")
_span.record_exception(exc)
raise exc
return request_adapter.get_serialization_writer_factory(
).get_serialization_writer(content_type)
finally:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_time_value(self) -> Optional[time]:
pass

@abstractmethod
def get_collection_of_primitive_values(self, primitive_type) -> Optional[List[T]]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> Optional[List[T]]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -128,7 +128,7 @@ def get_collection_of_primitive_values(self, primitive_type) -> Optional[List[T]
pass

@abstractmethod
def get_collection_of_object_values(self, factory: ParsableFactory) -> Optional[List[U]]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> Optional[List[U]]:
"""Gets the collection of model object values of the node
Args:
factory (ParsableFactory): The factory to use to create the model object.
Expand Down
10 changes: 5 additions & 5 deletions packages/http/httpx/kiota_http/httpx_request_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from kiota_abstractions.api_error import APIError
from kiota_abstractions.authentication import AuthenticationProvider
from kiota_abstractions.request_adapter import RequestAdapter, ResponseType
from kiota_abstractions.request_adapter import RequestAdapter, ResponseType, PrimitiveType
from kiota_abstractions.request_information import RequestInformation
from kiota_abstractions.serialization import (
Parsable,
Expand Down Expand Up @@ -250,20 +250,20 @@ async def send_collection_async(
async def send_collection_of_primitive_async(
self,
request_info: RequestInformation,
response_type: ResponseType,
response_type: type[PrimitiveType],
error_map: Optional[Dict[str, type[ParsableFactory]]],
) -> Optional[List[ResponseType]]:
) -> Optional[List[PrimitiveType]]:
"""Excutes the HTTP request specified by the given RequestInformation and returns the
deserialized response model collection.
Args:
request_info (RequestInformation): the request info to execute.
response_type (ResponseType): the class of the response model
response_type (PrimitiveType): the class of the response model
to deserialize the response into.
error_map (Dict[str, type[ParsableFactory]]): the error dict to use in
case of a failed request.
Returns:
Optional[List[ResponseType]]: he deserialized response model collection.
Optional[List[PrimitiveType]]: The deserialized primitive type collection.
"""
parent_span = self.start_tracing_span(request_info, "send_collection_of_primitive_async")
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_child_node(self, field_name: str) -> Optional[ParseNode]:
return FormParseNode(self._fields[field_name])
return None

def get_collection_of_primitive_values(self, primitive_type: type) -> Optional[List[T]]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> Optional[List[T]]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -189,7 +189,7 @@ def get_collection_of_primitive_values(self, primitive_type: type) -> Optional[L
return result
raise Exception(f"Encountered an unknown type during deserialization {primitive_type}")

def get_collection_of_object_values(self, factory: ParsableFactory) -> Optional[List[U]]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> Optional[List[U]]:
raise Exception("Collection of object values is not supported with uri form encoding.")

def get_collection_of_enum_values(self, enum_class: K) -> Optional[List[K]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def get_time_value(self) -> Optional[time]:
return datetime_obj
return None

def get_collection_of_primitive_values(self, primitive_type: Any) -> Optional[List[T]]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> Optional[List[T]]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -161,7 +161,7 @@ def func(item):
return list(map(func, json.loads(self._json_node)))
return list(map(func, list(self._json_node)))

def get_collection_of_object_values(self, factory: ParsableFactory) -> Optional[List[U]]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> Optional[List[U]]:
"""Gets the collection of type U values from the json node
Returns:
List[U]: The collection of model object values of the node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_time_value(self) -> Optional[time]:
return datetime_obj.time()
return None

def get_collection_of_primitive_values(self, primitive_type) -> List[T]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> List[T]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -142,7 +142,7 @@ def get_collection_of_primitive_values(self, primitive_type) -> List[T]:
"""
raise Exception(self.NO_STRUCTURED_DATA_MESSAGE)

def get_collection_of_object_values(self, factory: ParsableFactory) -> List[U]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> List[U]:
"""Gets the collection of type U values from the text node
Returns:
List[U]: The collection of model object values of the node
Expand Down

0 comments on commit 92cf4c5

Please sign in to comment.