diff --git a/pokemon_v2/api.py b/pokemon_v2/api.py index cf0b7052c..622d17d48 100644 --- a/pokemon_v2/api.py +++ b/pokemon_v2/api.py @@ -1,11 +1,15 @@ import re import subprocess +from typing import Optional + from rest_framework import viewsets +from rest_framework.request import Request from rest_framework.response import Response +from rest_framework.serializers import BaseSerializer from rest_framework.views import APIView -from django.shortcuts import get_object_or_404 +from django.db.models import Model, Q, QuerySet from django.http import Http404 -from django.db.models import Q +from django.shortcuts import get_object_or_404 from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter from drf_spectacular.types import OpenApiTypes @@ -25,9 +29,9 @@ class ListOrDetailSerialRelation: for list or detail view. """ - list_serializer_class = None + list_serializer_class: Optional[type[BaseSerializer]] = None - def get_serializer_class(self): + def get_serializer_class(self) -> type[BaseSerializer]: if self.action == "list" and self.list_serializer_class is not None: return self.list_serializer_class return self.serializer_class @@ -39,11 +43,11 @@ class NameOrIdRetrieval: pk (in this case ID) or by name """ - idPattern = re.compile(r"^-?[0-9]+$") + idPattern: re.Pattern[str] = re.compile(r"^-?[0-9]+$") # Allow alphanumeric, hyphen, plus, and space (Space added for test cases using name for lookup, ex: 'base pkm') - namePattern = re.compile(r"^[0-9A-Za-z\-\+ ]+$") + namePattern: re.Pattern[str] = re.compile(r"^[0-9A-Za-z\-\+ ]+$") - def get_queryset(self): + def get_queryset(self) -> QuerySet: queryset = super().get_queryset() filter = self.request.GET.get("q", "") @@ -52,7 +56,7 @@ def get_queryset(self): return queryset - def get_object(self): + def get_object(self) -> Model: queryset = self.get_queryset() queryset = self.filter_queryset(queryset) lookup = self.kwargs["pk"] @@ -94,7 +98,7 @@ class PokeapiCommonViewset( ListOrDetailSerialRelation, NameOrIdRetrieval, viewsets.ReadOnlyModelViewSet ): @extend_schema(parameters=[retrieve_path_parameter]) - def retrieve(self, request, pk=None): + def retrieve(self, request: Request, pk: Optional[str] = None) -> Response: return super().retrieve(request, pk) pass @@ -978,7 +982,7 @@ class VersionGroupResource(PokeapiCommonViewset): }, ) class PokemonEncounterView(APIView): - def get(self, request, pokemon_id): + def get(self, request: Request, pokemon_id: str) -> Response: self.context = dict(request=request) try: @@ -1076,7 +1080,7 @@ def get(self, request, pokemon_id): }, ) class PokeapiMetaViewset(viewsets.ViewSet): - def list(self, request): + def list(self, request: Request) -> Response: try: git_hash = ( subprocess.check_output(