diff --git a/Github/http.py b/Github/http.py index 45b0b07..6ff70ed 100644 --- a/Github/http.py +++ b/Github/http.py @@ -2,20 +2,19 @@ from __future__ import annotations -import io import json import re -from collections import namedtuple from datetime import datetime from types import SimpleNamespace -from typing import TYPE_CHECKING, Dict, Union, List +from typing import Dict, NamedTuple, Optional, Type, Union, List +from typing_extensions import TypeAlias import platform import aiohttp from .exceptions import * from .exceptions import GistNotFound, RepositoryAlreadyExists, MissingPermissions -from .objects import APIObject, User, Gist, Repository, Organization, File +from .objects import User, Gist, Repository, File from .urls import * from . import __version__ @@ -26,7 +25,7 @@ __all__ = ( LINK_PARSING_RE = re.compile(r"<(\S+(\S))>; rel=\"(\S+)\"") -Rates = namedtuple('Rates', ('remaining', 'used', 'total', 'reset_when', 'last_request')) +Rates = NamedTuple('Rates', 'remaining', 'used', 'total', 'reset_when', 'last_request') # aiohttp request tracking / checking bits async def on_req_start( @@ -58,6 +57,8 @@ trace_config = aiohttp.TraceConfig() trace_config.on_request_start.append(on_req_start) trace_config.on_request_end.append(on_req_end) +APIType: TypeAlias = Union[User, Gist, Repository] + async def make_session(*, headers: Dict[str, str], authorization: Union[aiohttp.BasicAuth, None]) -> aiohttp.ClientSession: """This makes the ClientSession, attaching the trace config and ensuring a UA header is present.""" if not headers.get('User-Agent'): @@ -78,52 +79,56 @@ class Paginator: self.session = session self.response = response self.should_paginate = bool(self.response.headers.get('Link', False)) - types: Dict[str, APIObject] = { + types: Dict[str, Type[APIType]] = { # note: the type checker doesnt see subclasses like that 'user': User, 'gist' : Gist, 'repo' : Repository } - self.target_type = types[target_type] + self.target_type: Type[APIType] = types[target_type] self.pages = {} self.is_exhausted = False self.current_page = 1 self.next_page = self.current_page + 1 self.parse_header(response) - async def fetch_page(self, link) -> Dict[str, Union[str, int]]: + async def fetch_page(self, link: str) -> Dict[str, Union[str, int]]: """Fetches a specific page and returns the JSON.""" return await (await self.session.get(link)).json() - async def early_return(self) -> List[APIObject]: + async def early_return(self) -> List[APIType]: # I don't rightly remember what this does differently, may have a good ol redesign later - return [self.target_type(data, self.session) for data in await self.response.json()] + return [self.target_type(data, self) for data in await self.response.json()] # type: ignore - async def exhaust(self) -> List[APIObject]: + async def exhaust(self) -> List[APIType]: """Iterates through all of the pages for the relevant object and creates them.""" if self.should_paginate: return await self.early_return() - out = [] + + out: List[APIType] = [] for page in range(1, self.max_page+1): result = await self.session.get(self.bare_link + str(page)) - out.extend([self.target_type(item, self.session) for item in await result.json()]) + out.extend([self.target_type(item, self) for item in await result.json()]) # type: ignore + self.is_exhausted = True return out def parse_header(self, response: aiohttp.ClientResponse) -> None: """Predicts wether a call will exceed the ratelimit ahead of the call.""" - header = response.headers.get('Link') + header = response.headers['Link'] groups = LINK_PARSING_RE.findall(header) self.max_page = int(groups[1][1]) if int(response.headers['X-RateLimit-Remaining']) < self.max_page: raise WillExceedRatelimit(response, self.max_page) self.bare_link = groups[0][0][:-1] -GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]] +# GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]] +# Commentnig this out for now, consider using TypeDict's instead in the future <3 class http: - def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]): + def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]) -> None: if not headers.get('User-Agent'): headers['User-Agent'] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python/{platform.python_version()} aiohttp/{aiohttp.__version__}' + self._rates = Rates('', '', '', '', '') self.headers = headers self.auth = auth @@ -133,7 +138,7 @@ class http: async def start(self): self.session = aiohttp.ClientSession( - headers=self.headers, + headers=self.headers, # type: ignore auth=self.auth, trace_configs=[trace_config], ) @@ -144,10 +149,10 @@ class http: def update_headers(self, *, flush: bool = False, new_headers: Dict[str, Union[str, int]]): if flush: from multidict import CIMultiDict - self.session.headers = CIMultiDict(**new_headers) + self.session._default_headers = CIMultiDict(**new_headers) # type: ignore else: - self.session.headers = {**self.session.headers, **new_headers} - + self.session._default_headers = {**self.session.headers, **new_headers} # type: ignore + async def update_auth(self, *, username: str, token: str): auth = aiohttp.BasicAuth(username, token) headers = self.session.headers @@ -170,56 +175,59 @@ class http: await self.session.get(BASE_URL) return (datetime.utcnow() - start).total_seconds() - async def get_self(self) -> GithubUserData: + async def get_self(self) -> Dict[str, Union [str, int]]: """Returns the authenticated User's data""" result = await self.session.get(SELF_URL) if 200 <= result.status <= 299: return await result.json() raise InvalidToken - async def get_user(self, username: str) -> GithubUserData: + async def get_user(self, username: str) -> Dict[str, Union [str, int]]: """Returns a user's public data in JSON format.""" result = await self.session.get(USERS_URL.format(username)) if 200 <= result.status <= 299: return await result.json() raise UserNotFound - async def get_user_repos(self, _user: User) -> List[GithubRepoData]: + async def get_user_repos(self, _user: User) -> List[Dict[str, Union [str, int]]]: result = await self.session.get(USER_REPOS_URL.format(_user.login)) if 200 <= result.status <= 299: return await result.json() - else: - print('This shouldn\'t be reachable') + + print('This shouldn\'t be reachable') + return [] - async def get_user_gists(self, _user: User) -> List[GithubGistData]: + async def get_user_gists(self, _user: User) -> List[Dict[str, Union [str, int]]]: result = await self.session.get(USER_GISTS_URL.format(_user.login)) if 200 <= result.status <= 299: return await result.json() - else: - print('This shouldn\'t be reachable') + + print('This shouldn\'t be reachable') + return [] - async def get_user_orgs(self, _user: User) -> List[GithubOrgData]: + async def get_user_orgs(self, _user: User) -> List[Dict[str, Union [str, int]]]: result = await self.session.get(USER_ORGS_URL.format(_user.login)) if 200 <= result.status <= 299: return await result.json() - else: - print('This shouldn\'t be reachable') + + print('This shouldn\'t be reachable') + return [] - async def get_repo(self, owner: str, repo_name: str) -> GithubRepoData: + async def get_repo(self, owner: str, repo_name: str) -> Dict[str, Union [str, int]]: """Returns a Repo's raw JSON from the given owner and repo name.""" result = await self.session.get(REPO_URL.format(owner, repo_name)) if 200 <= result.status <= 299: return await result.json() raise RepositoryNotFound - async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> GithubIssueData: + async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> Dict[str, Union [str, int]]: """Returns a single issue's JSON from the given owner and repo name.""" result = await self.session.get(REPO_ISSUE_URL.format(owner, repo_name, issue_number)) if 200 <= result.status <= 299: return await result.json() raise IssueNotFound - async def delete_repo(self, owner: str, repo_name: str) -> None: + async def delete_repo(self, owner: str, repo_name: str) -> Optional[str]: """Deletes a Repo from the given owner and repo name.""" result = await self.session.delete(REPO_URL.format(owner, repo_name)) if 204 <= result.status <= 299: @@ -228,7 +236,7 @@ class http: raise MissingPermissions raise RepositoryNotFound - async def delete_gist(self, gist_id: int) -> None: + async def delete_gist(self, gist_id: int) -> Optional[str]: """Deletes a Gist from the given gist id.""" result = await self.session.delete(GIST_URL.format(gist_id)) if result.status == 204: @@ -237,14 +245,14 @@ class http: raise MissingPermissions raise GistNotFound - async def get_org(self, org_name: str) -> GithubOrgData: + async def get_org(self, org_name: str) -> Dict[str, Union [str, int]]: """Returns an org's public data in JSON format.""" result = await self.session.get(ORG_URL.format(org_name)) if 200 <= result.status <= 299: return await result.json() raise OrganizationNotFound - async def get_gist(self, gist_id: int) -> GithubGistData: + async def get_gist(self, gist_id: int) -> Dict[str, Union [str, int]]: """Returns a gist's raw JSON from the given gist id.""" result = await self.session.get(GIST_URL.format(gist_id)) if 200 <= result.status <= 299: @@ -257,7 +265,7 @@ class http: files: List['File'] = [], description: str = 'Default description', public: bool = False - ) -> GithubGistData: + ) -> Dict[str, Union [str, int]]: data = {} data['description'] = description data['public'] = public @@ -274,7 +282,7 @@ class http: return await result.json() raise InvalidToken - async def create_repo(self, name: str, description: str, public: bool, gitignore: str, license: str) -> GithubRepoData: + async def create_repo(self, name: str, description: str, public: bool, gitignore: Optional[str], license: Optional[str]) -> Dict[str, Union [str, int]]: """Creates a repo for you with given data""" data = { 'name': name,