Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions pokemon_v2/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import re
import subprocess
from typing import Optional

from rest_framework import viewsets
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import BaseSerializer
from rest_framework.views import APIView
from django.shortcuts import get_object_or_404
from django.db.models import Model, Q, QuerySet
from django.http import Http404
from django.db.models import Q
from django.shortcuts import get_object_or_404
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
from drf_spectacular.types import OpenApiTypes

Expand All @@ -25,9 +29,9 @@ class ListOrDetailSerialRelation:
for list or detail view.
"""

list_serializer_class = None
list_serializer_class: Optional[type[BaseSerializer]] = None

def get_serializer_class(self):
def get_serializer_class(self) -> type[BaseSerializer]:
if self.action == "list" and self.list_serializer_class is not None:
return self.list_serializer_class
return self.serializer_class
Expand All @@ -39,11 +43,11 @@ class NameOrIdRetrieval:
pk (in this case ID) or by name
"""

idPattern = re.compile(r"^-?[0-9]+$")
idPattern: re.Pattern[str] = re.compile(r"^-?[0-9]+$")
# Allow alphanumeric, hyphen, plus, and space (Space added for test cases using name for lookup, ex: 'base pkm')
namePattern = re.compile(r"^[0-9A-Za-z\-\+ ]+$")
namePattern: re.Pattern[str] = re.compile(r"^[0-9A-Za-z\-\+ ]+$")

def get_queryset(self):
def get_queryset(self) -> QuerySet:
queryset = super().get_queryset()
filter = self.request.GET.get("q", "")

Expand All @@ -52,7 +56,7 @@ def get_queryset(self):

return queryset

def get_object(self):
def get_object(self) -> Model:
queryset = self.get_queryset()
queryset = self.filter_queryset(queryset)
lookup = self.kwargs["pk"]
Expand Down Expand Up @@ -94,7 +98,7 @@ class PokeapiCommonViewset(
ListOrDetailSerialRelation, NameOrIdRetrieval, viewsets.ReadOnlyModelViewSet
):
@extend_schema(parameters=[retrieve_path_parameter])
def retrieve(self, request, pk=None):
def retrieve(self, request: Request, pk: Optional[str] = None) -> Response:
return super().retrieve(request, pk)

pass
Expand Down Expand Up @@ -978,7 +982,7 @@ class VersionGroupResource(PokeapiCommonViewset):
},
)
class PokemonEncounterView(APIView):
def get(self, request, pokemon_id):
def get(self, request: Request, pokemon_id: str) -> Response:
self.context = dict(request=request)

try:
Expand Down Expand Up @@ -1076,7 +1080,7 @@ def get(self, request, pokemon_id):
},
)
class PokeapiMetaViewset(viewsets.ViewSet):
def list(self, request):
def list(self, request: Request) -> Response:
try:
git_hash = (
subprocess.check_output(
Expand Down