1
Fork 0
mirror of https://github.com/RGBCube/GitHubWrapper synced 2025-07-25 07:27:43 +00:00

Refactor some typehints in http.py

This commit is contained in:
NextChai 2022-04-30 02:25:41 -04:00
parent 6e114bc5c5
commit d482cf1bbe

View file

@ -2,20 +2,19 @@
from __future__ import annotations from __future__ import annotations
import io
import json import json
import re import re
from collections import namedtuple
from datetime import datetime from datetime import datetime
from types import SimpleNamespace from types import SimpleNamespace
from typing import TYPE_CHECKING, Dict, Union, List from typing import Dict, NamedTuple, Optional, Type, Union, List
from typing_extensions import TypeAlias
import platform import platform
import aiohttp import aiohttp
from .exceptions import * from .exceptions import *
from .exceptions import GistNotFound, RepositoryAlreadyExists, MissingPermissions from .exceptions import GistNotFound, RepositoryAlreadyExists, MissingPermissions
from .objects import APIObject, User, Gist, Repository, Organization, File from .objects import User, Gist, Repository, File
from .urls import * from .urls import *
from . import __version__ from . import __version__
@ -26,7 +25,7 @@ __all__ = (
LINK_PARSING_RE = re.compile(r"<(\S+(\S))>; rel=\"(\S+)\"") LINK_PARSING_RE = re.compile(r"<(\S+(\S))>; rel=\"(\S+)\"")
Rates = namedtuple('Rates', ('remaining', 'used', 'total', 'reset_when', 'last_request')) Rates = NamedTuple('Rates', 'remaining', 'used', 'total', 'reset_when', 'last_request')
# aiohttp request tracking / checking bits # aiohttp request tracking / checking bits
async def on_req_start( async def on_req_start(
@ -58,6 +57,8 @@ trace_config = aiohttp.TraceConfig()
trace_config.on_request_start.append(on_req_start) trace_config.on_request_start.append(on_req_start)
trace_config.on_request_end.append(on_req_end) trace_config.on_request_end.append(on_req_end)
APIType: TypeAlias = Union[User, Gist, Repository]
async def make_session(*, headers: Dict[str, str], authorization: Union[aiohttp.BasicAuth, None]) -> aiohttp.ClientSession: async def make_session(*, headers: Dict[str, str], authorization: Union[aiohttp.BasicAuth, None]) -> aiohttp.ClientSession:
"""This makes the ClientSession, attaching the trace config and ensuring a UA header is present.""" """This makes the ClientSession, attaching the trace config and ensuring a UA header is present."""
if not headers.get('User-Agent'): if not headers.get('User-Agent'):
@ -78,52 +79,56 @@ class Paginator:
self.session = session self.session = session
self.response = response self.response = response
self.should_paginate = bool(self.response.headers.get('Link', False)) self.should_paginate = bool(self.response.headers.get('Link', False))
types: Dict[str, APIObject] = { types: Dict[str, Type[APIType]] = { # note: the type checker doesnt see subclasses like that
'user': User, 'user': User,
'gist' : Gist, 'gist' : Gist,
'repo' : Repository 'repo' : Repository
} }
self.target_type = types[target_type] self.target_type: Type[APIType] = types[target_type]
self.pages = {} self.pages = {}
self.is_exhausted = False self.is_exhausted = False
self.current_page = 1 self.current_page = 1
self.next_page = self.current_page + 1 self.next_page = self.current_page + 1
self.parse_header(response) self.parse_header(response)
async def fetch_page(self, link) -> Dict[str, Union[str, int]]: async def fetch_page(self, link: str) -> Dict[str, Union[str, int]]:
"""Fetches a specific page and returns the JSON.""" """Fetches a specific page and returns the JSON."""
return await (await self.session.get(link)).json() return await (await self.session.get(link)).json()
async def early_return(self) -> List[APIObject]: async def early_return(self) -> List[APIType]:
# I don't rightly remember what this does differently, may have a good ol redesign later # I don't rightly remember what this does differently, may have a good ol redesign later
return [self.target_type(data, self.session) for data in await self.response.json()] return [self.target_type(data, self) for data in await self.response.json()] # type: ignore
async def exhaust(self) -> List[APIObject]: async def exhaust(self) -> List[APIType]:
"""Iterates through all of the pages for the relevant object and creates them.""" """Iterates through all of the pages for the relevant object and creates them."""
if self.should_paginate: if self.should_paginate:
return await self.early_return() return await self.early_return()
out = []
out: List[APIType] = []
for page in range(1, self.max_page+1): for page in range(1, self.max_page+1):
result = await self.session.get(self.bare_link + str(page)) result = await self.session.get(self.bare_link + str(page))
out.extend([self.target_type(item, self.session) for item in await result.json()]) out.extend([self.target_type(item, self) for item in await result.json()]) # type: ignore
self.is_exhausted = True self.is_exhausted = True
return out return out
def parse_header(self, response: aiohttp.ClientResponse) -> None: def parse_header(self, response: aiohttp.ClientResponse) -> None:
"""Predicts wether a call will exceed the ratelimit ahead of the call.""" """Predicts wether a call will exceed the ratelimit ahead of the call."""
header = response.headers.get('Link') header = response.headers['Link']
groups = LINK_PARSING_RE.findall(header) groups = LINK_PARSING_RE.findall(header)
self.max_page = int(groups[1][1]) self.max_page = int(groups[1][1])
if int(response.headers['X-RateLimit-Remaining']) < self.max_page: if int(response.headers['X-RateLimit-Remaining']) < self.max_page:
raise WillExceedRatelimit(response, self.max_page) raise WillExceedRatelimit(response, self.max_page)
self.bare_link = groups[0][0][:-1] self.bare_link = groups[0][0][:-1]
GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]] # GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]]
# Commentnig this out for now, consider using TypeDict's instead in the future <3
class http: class http:
def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]): def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]) -> None:
if not headers.get('User-Agent'): if not headers.get('User-Agent'):
headers['User-Agent'] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python/{platform.python_version()} aiohttp/{aiohttp.__version__}' headers['User-Agent'] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python/{platform.python_version()} aiohttp/{aiohttp.__version__}'
self._rates = Rates('', '', '', '', '') self._rates = Rates('', '', '', '', '')
self.headers = headers self.headers = headers
self.auth = auth self.auth = auth
@ -133,7 +138,7 @@ class http:
async def start(self): async def start(self):
self.session = aiohttp.ClientSession( self.session = aiohttp.ClientSession(
headers=self.headers, headers=self.headers, # type: ignore
auth=self.auth, auth=self.auth,
trace_configs=[trace_config], trace_configs=[trace_config],
) )
@ -144,10 +149,10 @@ class http:
def update_headers(self, *, flush: bool = False, new_headers: Dict[str, Union[str, int]]): def update_headers(self, *, flush: bool = False, new_headers: Dict[str, Union[str, int]]):
if flush: if flush:
from multidict import CIMultiDict from multidict import CIMultiDict
self.session.headers = CIMultiDict(**new_headers) self.session._default_headers = CIMultiDict(**new_headers) # type: ignore
else: else:
self.session.headers = {**self.session.headers, **new_headers} self.session._default_headers = {**self.session.headers, **new_headers} # type: ignore
async def update_auth(self, *, username: str, token: str): async def update_auth(self, *, username: str, token: str):
auth = aiohttp.BasicAuth(username, token) auth = aiohttp.BasicAuth(username, token)
headers = self.session.headers headers = self.session.headers
@ -170,56 +175,59 @@ class http:
await self.session.get(BASE_URL) await self.session.get(BASE_URL)
return (datetime.utcnow() - start).total_seconds() return (datetime.utcnow() - start).total_seconds()
async def get_self(self) -> GithubUserData: async def get_self(self) -> Dict[str, Union [str, int]]:
"""Returns the authenticated User's data""" """Returns the authenticated User's data"""
result = await self.session.get(SELF_URL) result = await self.session.get(SELF_URL)
if 200 <= result.status <= 299: if 200 <= result.status <= 299:
return await result.json() return await result.json()
raise InvalidToken raise InvalidToken
async def get_user(self, username: str) -> GithubUserData: async def get_user(self, username: str) -> Dict[str, Union [str, int]]:
"""Returns a user's public data in JSON format.""" """Returns a user's public data in JSON format."""
result = await self.session.get(USERS_URL.format(username)) result = await self.session.get(USERS_URL.format(username))
if 200 <= result.status <= 299: if 200 <= result.status <= 299:
return await result.json() return await result.json()
raise UserNotFound raise UserNotFound
async def get_user_repos(self, _user: User) -> List[GithubRepoData]: async def get_user_repos(self, _user: User) -> List[Dict[str, Union [str, int]]]:
result = await self.session.get(USER_REPOS_URL.format(_user.login)) result = await self.session.get(USER_REPOS_URL.format(_user.login))
if 200 <= result.status <= 299: if 200 <= result.status <= 299:
return await result.json() return await result.json()
else:
print('This shouldn\'t be reachable') print('This shouldn\'t be reachable')
return []
async def get_user_gists(self, _user: User) -> List[GithubGistData]: async def get_user_gists(self, _user: User) -> List[Dict[str, Union [str, int]]]:
result = await self.session.get(USER_GISTS_URL.format(_user.login)) result = await self.session.get(USER_GISTS_URL.format(_user.login))
if 200 <= result.status <= 299: if 200 <= result.status <= 299:
return await result.json() return await result.json()
else:
print('This shouldn\'t be reachable') print('This shouldn\'t be reachable')
return []
async def get_user_orgs(self, _user: User) -> List[GithubOrgData]: async def get_user_orgs(self, _user: User) -> List[Dict[str, Union [str, int]]]:
result = await self.session.get(USER_ORGS_URL.format(_user.login)) result = await self.session.get(USER_ORGS_URL.format(_user.login))
if 200 <= result.status <= 299: if 200 <= result.status <= 299:
return await result.json() return await result.json()
else:
print('This shouldn\'t be reachable') print('This shouldn\'t be reachable')
return []
async def get_repo(self, owner: str, repo_name: str) -> GithubRepoData: async def get_repo(self, owner: str, repo_name: str) -> Dict[str, Union [str, int]]:
"""Returns a Repo's raw JSON from the given owner and repo name.""" """Returns a Repo's raw JSON from the given owner and repo name."""
result = await self.session.get(REPO_URL.format(owner, repo_name)) result = await self.session.get(REPO_URL.format(owner, repo_name))
if 200 <= result.status <= 299: if 200 <= result.status <= 299:
return await result.json() return await result.json()
raise RepositoryNotFound raise RepositoryNotFound
async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> GithubIssueData: async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> Dict[str, Union [str, int]]:
"""Returns a single issue's JSON from the given owner and repo name.""" """Returns a single issue's JSON from the given owner and repo name."""
result = await self.session.get(REPO_ISSUE_URL.format(owner, repo_name, issue_number)) result = await self.session.get(REPO_ISSUE_URL.format(owner, repo_name, issue_number))
if 200 <= result.status <= 299: if 200 <= result.status <= 299:
return await result.json() return await result.json()
raise IssueNotFound raise IssueNotFound
async def delete_repo(self, owner: str, repo_name: str) -> None: async def delete_repo(self, owner: str, repo_name: str) -> Optional[str]:
"""Deletes a Repo from the given owner and repo name.""" """Deletes a Repo from the given owner and repo name."""
result = await self.session.delete(REPO_URL.format(owner, repo_name)) result = await self.session.delete(REPO_URL.format(owner, repo_name))
if 204 <= result.status <= 299: if 204 <= result.status <= 299:
@ -228,7 +236,7 @@ class http:
raise MissingPermissions raise MissingPermissions
raise RepositoryNotFound raise RepositoryNotFound
async def delete_gist(self, gist_id: int) -> None: async def delete_gist(self, gist_id: int) -> Optional[str]:
"""Deletes a Gist from the given gist id.""" """Deletes a Gist from the given gist id."""
result = await self.session.delete(GIST_URL.format(gist_id)) result = await self.session.delete(GIST_URL.format(gist_id))
if result.status == 204: if result.status == 204:
@ -237,14 +245,14 @@ class http:
raise MissingPermissions raise MissingPermissions
raise GistNotFound raise GistNotFound
async def get_org(self, org_name: str) -> GithubOrgData: async def get_org(self, org_name: str) -> Dict[str, Union [str, int]]:
"""Returns an org's public data in JSON format.""" """Returns an org's public data in JSON format."""
result = await self.session.get(ORG_URL.format(org_name)) result = await self.session.get(ORG_URL.format(org_name))
if 200 <= result.status <= 299: if 200 <= result.status <= 299:
return await result.json() return await result.json()
raise OrganizationNotFound raise OrganizationNotFound
async def get_gist(self, gist_id: int) -> GithubGistData: async def get_gist(self, gist_id: int) -> Dict[str, Union [str, int]]:
"""Returns a gist's raw JSON from the given gist id.""" """Returns a gist's raw JSON from the given gist id."""
result = await self.session.get(GIST_URL.format(gist_id)) result = await self.session.get(GIST_URL.format(gist_id))
if 200 <= result.status <= 299: if 200 <= result.status <= 299:
@ -257,7 +265,7 @@ class http:
files: List['File'] = [], files: List['File'] = [],
description: str = 'Default description', description: str = 'Default description',
public: bool = False public: bool = False
) -> GithubGistData: ) -> Dict[str, Union [str, int]]:
data = {} data = {}
data['description'] = description data['description'] = description
data['public'] = public data['public'] = public
@ -274,7 +282,7 @@ class http:
return await result.json() return await result.json()
raise InvalidToken raise InvalidToken
async def create_repo(self, name: str, description: str, public: bool, gitignore: str, license: str) -> GithubRepoData: async def create_repo(self, name: str, description: str, public: bool, gitignore: Optional[str], license: Optional[str]) -> Dict[str, Union [str, int]]:
"""Creates a repo for you with given data""" """Creates a repo for you with given data"""
data = { data = {
'name': name, 'name': name,