mirror of
https://github.com/RGBCube/GitHubWrapper
synced 2025-05-31 13:08:12 +00:00
Refactor some typehints in http.py
This commit is contained in:
parent
6e114bc5c5
commit
d482cf1bbe
1 changed files with 47 additions and 39 deletions
|
@ -2,20 +2,19 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Dict, Union, List
|
||||
from typing import Dict, NamedTuple, Optional, Type, Union, List
|
||||
from typing_extensions import TypeAlias
|
||||
import platform
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .exceptions import *
|
||||
from .exceptions import GistNotFound, RepositoryAlreadyExists, MissingPermissions
|
||||
from .objects import APIObject, User, Gist, Repository, Organization, File
|
||||
from .objects import User, Gist, Repository, File
|
||||
from .urls import *
|
||||
from . import __version__
|
||||
|
||||
|
@ -26,7 +25,7 @@ __all__ = (
|
|||
|
||||
|
||||
LINK_PARSING_RE = re.compile(r"<(\S+(\S))>; rel=\"(\S+)\"")
|
||||
Rates = namedtuple('Rates', ('remaining', 'used', 'total', 'reset_when', 'last_request'))
|
||||
Rates = NamedTuple('Rates', 'remaining', 'used', 'total', 'reset_when', 'last_request')
|
||||
|
||||
# aiohttp request tracking / checking bits
|
||||
async def on_req_start(
|
||||
|
@ -58,6 +57,8 @@ trace_config = aiohttp.TraceConfig()
|
|||
trace_config.on_request_start.append(on_req_start)
|
||||
trace_config.on_request_end.append(on_req_end)
|
||||
|
||||
APIType: TypeAlias = Union[User, Gist, Repository]
|
||||
|
||||
async def make_session(*, headers: Dict[str, str], authorization: Union[aiohttp.BasicAuth, None]) -> aiohttp.ClientSession:
|
||||
"""This makes the ClientSession, attaching the trace config and ensuring a UA header is present."""
|
||||
if not headers.get('User-Agent'):
|
||||
|
@ -78,52 +79,56 @@ class Paginator:
|
|||
self.session = session
|
||||
self.response = response
|
||||
self.should_paginate = bool(self.response.headers.get('Link', False))
|
||||
types: Dict[str, APIObject] = {
|
||||
types: Dict[str, Type[APIType]] = { # note: the type checker doesnt see subclasses like that
|
||||
'user': User,
|
||||
'gist' : Gist,
|
||||
'repo' : Repository
|
||||
}
|
||||
self.target_type = types[target_type]
|
||||
self.target_type: Type[APIType] = types[target_type]
|
||||
self.pages = {}
|
||||
self.is_exhausted = False
|
||||
self.current_page = 1
|
||||
self.next_page = self.current_page + 1
|
||||
self.parse_header(response)
|
||||
|
||||
async def fetch_page(self, link) -> Dict[str, Union[str, int]]:
|
||||
async def fetch_page(self, link: str) -> Dict[str, Union[str, int]]:
|
||||
"""Fetches a specific page and returns the JSON."""
|
||||
return await (await self.session.get(link)).json()
|
||||
|
||||
async def early_return(self) -> List[APIObject]:
|
||||
async def early_return(self) -> List[APIType]:
|
||||
# I don't rightly remember what this does differently, may have a good ol redesign later
|
||||
return [self.target_type(data, self.session) for data in await self.response.json()]
|
||||
return [self.target_type(data, self) for data in await self.response.json()] # type: ignore
|
||||
|
||||
async def exhaust(self) -> List[APIObject]:
|
||||
async def exhaust(self) -> List[APIType]:
|
||||
"""Iterates through all of the pages for the relevant object and creates them."""
|
||||
if self.should_paginate:
|
||||
return await self.early_return()
|
||||
out = []
|
||||
|
||||
out: List[APIType] = []
|
||||
for page in range(1, self.max_page+1):
|
||||
result = await self.session.get(self.bare_link + str(page))
|
||||
out.extend([self.target_type(item, self.session) for item in await result.json()])
|
||||
out.extend([self.target_type(item, self) for item in await result.json()]) # type: ignore
|
||||
|
||||
self.is_exhausted = True
|
||||
return out
|
||||
|
||||
def parse_header(self, response: aiohttp.ClientResponse) -> None:
|
||||
"""Predicts wether a call will exceed the ratelimit ahead of the call."""
|
||||
header = response.headers.get('Link')
|
||||
header = response.headers['Link']
|
||||
groups = LINK_PARSING_RE.findall(header)
|
||||
self.max_page = int(groups[1][1])
|
||||
if int(response.headers['X-RateLimit-Remaining']) < self.max_page:
|
||||
raise WillExceedRatelimit(response, self.max_page)
|
||||
self.bare_link = groups[0][0][:-1]
|
||||
|
||||
GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]]
|
||||
# GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]]
|
||||
# Commentnig this out for now, consider using TypeDict's instead in the future <3
|
||||
|
||||
class http:
|
||||
def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]):
|
||||
def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]) -> None:
|
||||
if not headers.get('User-Agent'):
|
||||
headers['User-Agent'] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python/{platform.python_version()} aiohttp/{aiohttp.__version__}'
|
||||
|
||||
self._rates = Rates('', '', '', '', '')
|
||||
self.headers = headers
|
||||
self.auth = auth
|
||||
|
@ -133,7 +138,7 @@ class http:
|
|||
|
||||
async def start(self):
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=self.headers,
|
||||
headers=self.headers, # type: ignore
|
||||
auth=self.auth,
|
||||
trace_configs=[trace_config],
|
||||
)
|
||||
|
@ -144,10 +149,10 @@ class http:
|
|||
def update_headers(self, *, flush: bool = False, new_headers: Dict[str, Union[str, int]]):
|
||||
if flush:
|
||||
from multidict import CIMultiDict
|
||||
self.session.headers = CIMultiDict(**new_headers)
|
||||
self.session._default_headers = CIMultiDict(**new_headers) # type: ignore
|
||||
else:
|
||||
self.session.headers = {**self.session.headers, **new_headers}
|
||||
|
||||
self.session._default_headers = {**self.session.headers, **new_headers} # type: ignore
|
||||
|
||||
async def update_auth(self, *, username: str, token: str):
|
||||
auth = aiohttp.BasicAuth(username, token)
|
||||
headers = self.session.headers
|
||||
|
@ -170,56 +175,59 @@ class http:
|
|||
await self.session.get(BASE_URL)
|
||||
return (datetime.utcnow() - start).total_seconds()
|
||||
|
||||
async def get_self(self) -> GithubUserData:
|
||||
async def get_self(self) -> Dict[str, Union [str, int]]:
|
||||
"""Returns the authenticated User's data"""
|
||||
result = await self.session.get(SELF_URL)
|
||||
if 200 <= result.status <= 299:
|
||||
return await result.json()
|
||||
raise InvalidToken
|
||||
|
||||
async def get_user(self, username: str) -> GithubUserData:
|
||||
async def get_user(self, username: str) -> Dict[str, Union [str, int]]:
|
||||
"""Returns a user's public data in JSON format."""
|
||||
result = await self.session.get(USERS_URL.format(username))
|
||||
if 200 <= result.status <= 299:
|
||||
return await result.json()
|
||||
raise UserNotFound
|
||||
|
||||
async def get_user_repos(self, _user: User) -> List[GithubRepoData]:
|
||||
async def get_user_repos(self, _user: User) -> List[Dict[str, Union [str, int]]]:
|
||||
result = await self.session.get(USER_REPOS_URL.format(_user.login))
|
||||
if 200 <= result.status <= 299:
|
||||
return await result.json()
|
||||
else:
|
||||
print('This shouldn\'t be reachable')
|
||||
|
||||
print('This shouldn\'t be reachable')
|
||||
return []
|
||||
|
||||
async def get_user_gists(self, _user: User) -> List[GithubGistData]:
|
||||
async def get_user_gists(self, _user: User) -> List[Dict[str, Union [str, int]]]:
|
||||
result = await self.session.get(USER_GISTS_URL.format(_user.login))
|
||||
if 200 <= result.status <= 299:
|
||||
return await result.json()
|
||||
else:
|
||||
print('This shouldn\'t be reachable')
|
||||
|
||||
print('This shouldn\'t be reachable')
|
||||
return []
|
||||
|
||||
async def get_user_orgs(self, _user: User) -> List[GithubOrgData]:
|
||||
async def get_user_orgs(self, _user: User) -> List[Dict[str, Union [str, int]]]:
|
||||
result = await self.session.get(USER_ORGS_URL.format(_user.login))
|
||||
if 200 <= result.status <= 299:
|
||||
return await result.json()
|
||||
else:
|
||||
print('This shouldn\'t be reachable')
|
||||
|
||||
print('This shouldn\'t be reachable')
|
||||
return []
|
||||
|
||||
async def get_repo(self, owner: str, repo_name: str) -> GithubRepoData:
|
||||
async def get_repo(self, owner: str, repo_name: str) -> Dict[str, Union [str, int]]:
|
||||
"""Returns a Repo's raw JSON from the given owner and repo name."""
|
||||
result = await self.session.get(REPO_URL.format(owner, repo_name))
|
||||
if 200 <= result.status <= 299:
|
||||
return await result.json()
|
||||
raise RepositoryNotFound
|
||||
|
||||
async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> GithubIssueData:
|
||||
async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> Dict[str, Union [str, int]]:
|
||||
"""Returns a single issue's JSON from the given owner and repo name."""
|
||||
result = await self.session.get(REPO_ISSUE_URL.format(owner, repo_name, issue_number))
|
||||
if 200 <= result.status <= 299:
|
||||
return await result.json()
|
||||
raise IssueNotFound
|
||||
|
||||
async def delete_repo(self, owner: str, repo_name: str) -> None:
|
||||
async def delete_repo(self, owner: str, repo_name: str) -> Optional[str]:
|
||||
"""Deletes a Repo from the given owner and repo name."""
|
||||
result = await self.session.delete(REPO_URL.format(owner, repo_name))
|
||||
if 204 <= result.status <= 299:
|
||||
|
@ -228,7 +236,7 @@ class http:
|
|||
raise MissingPermissions
|
||||
raise RepositoryNotFound
|
||||
|
||||
async def delete_gist(self, gist_id: int) -> None:
|
||||
async def delete_gist(self, gist_id: int) -> Optional[str]:
|
||||
"""Deletes a Gist from the given gist id."""
|
||||
result = await self.session.delete(GIST_URL.format(gist_id))
|
||||
if result.status == 204:
|
||||
|
@ -237,14 +245,14 @@ class http:
|
|||
raise MissingPermissions
|
||||
raise GistNotFound
|
||||
|
||||
async def get_org(self, org_name: str) -> GithubOrgData:
|
||||
async def get_org(self, org_name: str) -> Dict[str, Union [str, int]]:
|
||||
"""Returns an org's public data in JSON format."""
|
||||
result = await self.session.get(ORG_URL.format(org_name))
|
||||
if 200 <= result.status <= 299:
|
||||
return await result.json()
|
||||
raise OrganizationNotFound
|
||||
|
||||
async def get_gist(self, gist_id: int) -> GithubGistData:
|
||||
async def get_gist(self, gist_id: int) -> Dict[str, Union [str, int]]:
|
||||
"""Returns a gist's raw JSON from the given gist id."""
|
||||
result = await self.session.get(GIST_URL.format(gist_id))
|
||||
if 200 <= result.status <= 299:
|
||||
|
@ -257,7 +265,7 @@ class http:
|
|||
files: List['File'] = [],
|
||||
description: str = 'Default description',
|
||||
public: bool = False
|
||||
) -> GithubGistData:
|
||||
) -> Dict[str, Union [str, int]]:
|
||||
data = {}
|
||||
data['description'] = description
|
||||
data['public'] = public
|
||||
|
@ -274,7 +282,7 @@ class http:
|
|||
return await result.json()
|
||||
raise InvalidToken
|
||||
|
||||
async def create_repo(self, name: str, description: str, public: bool, gitignore: str, license: str) -> GithubRepoData:
|
||||
async def create_repo(self, name: str, description: str, public: bool, gitignore: Optional[str], license: Optional[str]) -> Dict[str, Union [str, int]]:
|
||||
"""Creates a repo for you with given data"""
|
||||
data = {
|
||||
'name': name,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue