1
Fork 0
mirror of https://github.com/RGBCube/GitHubWrapper synced 2025-05-17 06:25:10 +00:00
GitHubWrapper/github/http.py
2022-05-11 11:33:22 +05:30

326 lines
13 KiB
Python

# == http.py ==#
from __future__ import annotations

import json
import platform
import re
import time
from asyncio.base_subprocess import ReadSubprocessPipeProto
from base64 import b64encode
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Type, Union
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit

import aiohttp
from typing_extensions import TypeAlias, reveal_type

from .exceptions import *
from .exceptions import GistNotFound, RepositoryAlreadyExists, MissingPermissions
from .exceptions import FileAlreadyExists
from .exceptions import ResourceAlreadyExists
from .objects import User, Gist, Repository, File, bytes_to_b64
from .urls import *
from . import __version__
# Names re-exported by ``from .http import *``.
__all__: Tuple[str, ...] = ('Paginator', 'http')
# Matches one entry of a GitHub ``Link`` header, e.g.
#   <https://api.github.com/user/repos?page=2>; rel="next"
# Groups: (1) the full URL, (2) the whole trailing number of that URL
# (usually the page number) or its last character when it does not end in
# digits, (3) the ``rel`` value.
LINK_PARSING_RE = re.compile(r"<(\S+?(\d+|\S))>; rel=\"(\S+)\"")
class Rates(NamedTuple):
    """Ratelimit snapshot kept on a session as ``session._rates``.

    Sessions start with every field set to ``''``; ``on_req_end`` refreshes
    the snapshot from the ``X-RateLimit-*`` headers of completed requests.
    """

    remaining: str  # X-RateLimit-Remaining
    used: str  # X-RateLimit-Used
    total: str  # X-RateLimit-Limit
    reset_when: Union[datetime, str]  # parsed X-RateLimit-Reset, or '' before any response
    last_request: Union[datetime, str]  # time the last request finished, or '' before any response
# aiohttp request tracking / checking bits
async def on_req_start(
    session: aiohttp.ClientSession,
    ctx: SimpleNamespace,
    params: aiohttp.TraceRequestStartParams,
) -> None:
    """Trace hook run before each request is sent.

    Reserved for checking that a request will not overrun the ratelimit; it
    currently does nothing and always returns ``None``.
    """
    return None
async def on_req_end(session: aiohttp.ClientSession, ctx: SimpleNamespace, params: aiohttp.TraceRequestEndParams) -> None:
    """After-request hook that records the ratelimit state of the response.

    Reads the ``X-RateLimit-Remaining``, ``-Used``, ``-Limit`` and ``-Reset``
    headers of the finished response and stores them on ``session._rates`` as
    a :class:`Rates` tuple. A response missing any of these headers (or with a
    non-integer reset value) leaves the previous snapshot untouched instead of
    raising, since an exception here would surface from the request itself.

    ``reset_when`` and ``last_request`` are both timezone-aware UTC datetimes,
    so they can be compared or subtracted directly.
    """
    headers = params.response.headers
    try:
        remaining = headers['X-RateLimit-Remaining']
        used = headers['X-RateLimit-Used']
        total = headers['X-RateLimit-Limit']
        reset_epoch = int(headers['X-RateLimit-Reset'])
    except (KeyError, ValueError):
        # Not a ratelimited response (or malformed headers): nothing to record.
        return
    reset_when = datetime.fromtimestamp(reset_epoch, tz=timezone.utc)
    last_req = datetime.now(timezone.utc)
    session._rates = Rates(remaining, used, total, reset_when, last_req)
# Union of the API object types a Paginator can build.
APIType: TypeAlias = Union[User, Gist, Repository]

# Single shared TraceConfig that attaches the ratelimit hooks above to every
# session created in this module (make_session and http.start pass it in).
trace_config = aiohttp.TraceConfig()
trace_config.on_request_start.append(on_req_start)
trace_config.on_request_end.append(on_req_end)
async def make_session(*, headers: Dict[str, str], authorization: Union[aiohttp.BasicAuth, None]) -> aiohttp.ClientSession:
    """This makes the ClientSession, attaching the trace config and ensuring a UA header is present.

    :param headers: default headers for every request of the session. The
        mapping is copied before a default User-Agent is added, so the
        caller's dict is not modified.
    :param authorization: optional basic auth applied to every request.
    :return: a new session whose ``_rates`` starts as an empty :class:`Rates`.
    """
    session_headers = dict(headers)
    if not session_headers.get('User-Agent'):
        session_headers['User-Agent'] = (
            f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} '
            f'Python {platform.python_version()} aiohttp {aiohttp.__version__}'
        )
    session = aiohttp.ClientSession(auth=authorization, headers=session_headers, trace_configs=[trace_config])
    # Placeholder snapshot until on_req_end records real ratelimit values.
    session._rates = Rates('', '', '', '', '')
    return session
# pagination
class Paginator:
    """This class handles pagination for objects like Repos and Orgs.

    It is built from the first response of a listing request. When that
    response carries a ``Link`` header, the listing spans several pages and
    :meth:`exhaust` fetches all of them. Otherwise the first response already
    holds every item.
    """

    def __init__(self, session: aiohttp.ClientSession, response: aiohttp.ClientResponse, target_type: str):
        self.session = session
        self.response = response
        self.should_paginate = bool(self.response.headers.get('Link', False))
        types: Dict[str, Type[APIType]] = {  # note: the type checker doesn't see subclasses like that
            'user': User,
            'gist': Gist,
            'repo': Repository,
        }
        self.target_type: Type[APIType] = types[target_type]
        self.pages = {}
        self.is_exhausted = False
        self.current_page = 1
        self.next_page = self.current_page + 1
        # Single-page defaults; parse_header replaces them for paginated listings.
        self.max_page = 1
        self.bare_link = ''
        if self.should_paginate:
            # Only a paginated response has a Link header to parse.
            self.parse_header(response)

    async def fetch_page(self, link: str) -> Dict[str, Union[str, int]]:
        """Fetches a specific page and returns the JSON."""
        return await (await self.session.get(link)).json()

    async def early_return(self) -> List[APIType]:
        """Build objects from the initial response only (the single-page case)."""
        return [self.target_type(data, self) for data in await self.response.json()]  # type: ignore

    async def exhaust(self) -> List[APIType]:
        """Iterates through all of the pages for the relevant object and creates them.

        A single-page listing is built directly from the initial response.
        Otherwise pages ``1..max_page`` are fetched in order.
        """
        if not self.should_paginate:
            out = await self.early_return()
            self.is_exhausted = True
            return out
        out: List[APIType] = []
        for page in range(1, self.max_page + 1):
            result = await self.session.get(self.bare_link + str(page))
            items = await result.json()
            out.extend(self.target_type(item, self) for item in items)  # type: ignore
        self.is_exhausted = True
        return out

    @staticmethod
    def _page_of(url: str) -> Optional[int]:
        """Return the ``page`` query parameter of *url* as an int, or None when absent."""
        for key, value in parse_qsl(urlsplit(url).query):
            if key == 'page' and value.isdigit():
                return int(value)
        return None

    def parse_header(self, response: aiohttp.ClientResponse) -> None:
        """Read the ``Link`` header to learn the page count and the page-URL template.

        Sets ``max_page`` to the highest page number among the linked URLs
        (the ``rel="last"`` page on a first-page response). Sets ``bare_link``
        to a URL ending in ``page=`` that only needs a page number appended.

        :raises WillExceedRatelimit: when the response reports fewer remaining
            requests than there are pages to fetch.
        """
        links = {rel: url for url, _, rel in LINK_PARSING_RE.findall(response.headers['Link'])}
        if not links:
            # Nothing usable in the header: fall back to single-page behaviour.
            self.should_paginate = False
            return
        page_numbers = [n for n in map(self._page_of, links.values()) if n is not None]
        self.max_page = max(page_numbers, default=self.current_page)
        remaining = response.headers.get('X-RateLimit-Remaining')
        if remaining is not None and int(remaining) < self.max_page:
            raise WillExceedRatelimit(response, self.max_page)
        # Rebuild the query with ``page`` moved to the end and left empty, so that
        # ``bare_link + str(n)`` addresses page n whatever the parameter order was.
        parts = urlsplit(links.get('last') or next(iter(links.values())))
        query = [(k, v) for k, v in parse_qsl(parts.query, keep_blank_values=True) if k != 'page']
        query.append(('page', ''))
        self.bare_link = urlunsplit(parts._replace(query=urlencode(query)))
# GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]]
# Commenting this out for now; consider using TypedDicts instead in the future <3
class http:
    """Low-level client for the GitHub REST endpoints used by this package.

    Awaiting an instance (``await http(headers, auth)``) opens its aiohttp
    session via :meth:`start`. The endpoint helpers return the decoded JSON
    of a successful response; most of them raise one of the package's
    exceptions otherwise.
    """

    def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]) -> None:
        # Work on a copy so the default User-Agent is not written into the caller's dict.
        headers = dict(headers)
        if not headers.get('User-Agent'):
            headers['User-Agent'] = (
                f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} '
                f'Python/{platform.python_version()} aiohttp/{aiohttp.__version__}'
            )
        self._rates = Rates('', '', '', '', '')
        self.headers = headers
        self.auth = auth

    def __await__(self):
        return self.start().__await__()

    async def start(self):
        """Open the aiohttp session with the stored headers, auth and trace hooks."""
        self.session = aiohttp.ClientSession(
            headers=self.headers,  # type: ignore
            auth=self.auth,
            trace_configs=[trace_config],
        )
        if not hasattr(self.session, "_rates"):
            self.session._rates = Rates('', '', '', '', '')
        return self

    def update_headers(self, *, flush: bool = False, new_headers: Dict[str, Union[str, int]]):
        """Replace (``flush=True``) or extend the session's default headers.

        Both branches store a case-insensitive CIMultiDict, so e.g. ``accept``
        overrides an existing ``Accept`` instead of sitting next to it.
        """
        from multidict import CIMultiDict

        if flush:
            self.session._default_headers = CIMultiDict(**new_headers)  # type: ignore
        else:
            merged = CIMultiDict(self.session.headers)
            merged.update(new_headers)  # overwrites existing keys case-insensitively
            self.session._default_headers = merged  # type: ignore

    async def update_auth(self, *, username: str, token: str):
        """Replace the session with one authenticated as ``username``/``token``.

        Headers and trace hooks carry over. ``self.auth`` is updated so that
        :meth:`data` reports the new credentials. The new session starts with
        an empty ratelimit snapshot, like one created by :meth:`start`.
        """
        auth = aiohttp.BasicAuth(username, token)
        headers = self.session.headers
        config = self.session.trace_configs
        await self.session.close()
        self.session = aiohttp.ClientSession(headers=headers, auth=auth, trace_configs=config)
        self.session._rates = Rates('', '', '', '', '')
        self.auth = auth

    def data(self):
        """Return the session's current headers (as a plain dict) and the auth in use."""
        headers = {**self.session.headers}
        return {'headers': headers, 'auth': self.auth}

    async def latency(self):
        """Returns the latency of the current session, in seconds.

        Measures the time until the response to a GET of BASE_URL arrives,
        using a monotonic clock, then releases the response.
        """
        start = time.perf_counter()
        async with self.session.get(BASE_URL):
            elapsed = time.perf_counter() - start
        return elapsed

    async def get_self(self) -> Dict[str, Union[str, int]]:
        """Returns the authenticated User's data; raises InvalidToken on a non-2xx response."""
        result = await self.session.get(SELF_URL)
        if 200 <= result.status <= 299:
            return await result.json()
        raise InvalidToken

    async def get_user(self, username: str) -> Dict[str, Union[str, int]]:
        """Returns a user's public data in JSON format; raises UserNotFound on a non-2xx response."""
        result = await self.session.get(USERS_URL.format(username))
        if 200 <= result.status <= 299:
            return await result.json()
        raise UserNotFound

    async def get_user_repos(self, _user: User) -> List[Dict[str, Union[str, int]]]:
        """Return the raw JSON list of ``_user``'s repositories, or ``[]`` on a non-2xx response."""
        result = await self.session.get(USER_REPOS_URL.format(_user.login))
        if 200 <= result.status <= 299:
            return await result.json()
        return []

    async def get_user_gists(self, _user: User) -> List[Dict[str, Union[str, int]]]:
        """Return the raw JSON list of ``_user``'s gists, or ``[]`` on a non-2xx response."""
        result = await self.session.get(USER_GISTS_URL.format(_user.login))
        if 200 <= result.status <= 299:
            return await result.json()
        return []

    async def get_user_orgs(self, _user: User) -> List[Dict[str, Union[str, int]]]:
        """Return the raw JSON list of ``_user``'s organizations, or ``[]`` on a non-2xx response."""
        result = await self.session.get(USER_ORGS_URL.format(_user.login))
        if 200 <= result.status <= 299:
            return await result.json()
        return []

    async def get_repo(self, owner: str, repo_name: str) -> Optional[Dict[str, Union[str, int]]]:
        """Returns a Repo's raw JSON from the given owner and repo name; raises RepositoryNotFound otherwise."""
        result = await self.session.get(REPO_URL.format(owner, repo_name))
        if 200 <= result.status <= 299:
            return await result.json()
        raise RepositoryNotFound

    async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> Optional[Dict[str, Any]]:
        """Returns a single issue's JSON from the given owner and repo name; raises IssueNotFound otherwise."""
        result = await self.session.get(REPO_ISSUE_URL.format(owner, repo_name, issue_number))
        if 200 <= result.status <= 299:
            return await result.json()
        raise IssueNotFound

    async def delete_repo(self, owner: Optional[str], repo_name: str) -> Optional[str]:
        """Deletes a Repo from the given owner and repo name.

        Raises MissingPermissions on 403 and RepositoryNotFound on any other
        status outside 204-299.
        """
        result = await self.session.delete(REPO_URL.format(owner, repo_name))
        if 204 <= result.status <= 299:
            return 'Successfully deleted repository.'
        if result.status == 403:
            raise MissingPermissions
        raise RepositoryNotFound

    async def delete_gist(self, gist_id: Union[str, int]) -> Optional[str]:
        """Deletes a Gist from the given gist id.

        Raises MissingPermissions on 403 and GistNotFound on any other non-204 status.
        """
        result = await self.session.delete(GIST_URL.format(gist_id))
        if result.status == 204:
            return 'Successfully deleted gist.'
        if result.status == 403:
            raise MissingPermissions
        raise GistNotFound

    async def get_org(self, org_name: str) -> Dict[str, Union[str, int]]:
        """Returns an org's public data in JSON format; raises OrganizationNotFound otherwise."""
        result = await self.session.get(ORG_URL.format(org_name))
        if 200 <= result.status <= 299:
            return await result.json()
        raise OrganizationNotFound

    async def get_gist(self, gist_id: str) -> Dict[str, Union[str, int]]:
        """Returns a gist's raw JSON from the given gist id; raises GistNotFound otherwise."""
        result = await self.session.get(GIST_URL.format(gist_id))
        if 200 <= result.status <= 299:
            return await result.json()
        raise GistNotFound

    async def create_gist(
        self, *, files: Optional[List['File']] = None, description: str = 'Default description', public: bool = False
    ) -> Dict[str, Union[str, int]]:
        """Create a gist containing ``files`` and return its JSON.

        ``files`` defaults to no files. Raises InvalidToken unless GitHub
        answers ``201 Created``.
        """
        payload = {
            'description': description,
            'public': public,
            'files': {
                # Sending 'filename' alongside the content helps editing the file.
                file.filename: {'filename': file.filename, 'content': file.read()}
                for file in files or ()
            },
        }
        request_headers = {**dict(self.session.headers), 'Accept': 'application/vnd.github.v3+json'}
        result = await self.session.post(CREATE_GIST_URL, data=json.dumps(payload), headers=request_headers)
        if result.status == 201:
            return await result.json()
        raise InvalidToken

    async def create_repo(
        self, name: str, description: str, public: bool, gitignore: Optional[str], license: Optional[str]
    ) -> Dict[str, Union[str, int]]:
        """Creates a repo for you with given data.

        Raises NoAuthProvided on 401 and RepositoryAlreadyExists on any other
        non-2xx status.
        """
        data: Dict[str, Any] = {
            'name': name,
            'description': description,
            # The create-repository endpoint takes ``private``; a ``public`` key is ignored.
            'private': not public,
        }
        # Templates are optional; only send them when given.
        if gitignore is not None:
            data['gitignore_template'] = gitignore
        if license is not None:
            data['license_template'] = license
        result = await self.session.post(CREATE_REPO_URL, data=json.dumps(data))
        if 200 <= result.status <= 299:
            return await result.json()
        if result.status == 401:
            raise NoAuthProvided
        raise RepositoryAlreadyExists

    async def add_file(self, owner: str, repo_name: str, filename: str, content: str, message: str, branch: str):
        """Adds a file to the given repo.

        Returns the JSON of a 2xx response. Raises NoAuthProvided on 401 and
        FileAlreadyExists on 409 or 422. For any other status it returns a
        ``(json, status)`` tuple.
        """
        data = {
            'content': bytes_to_b64(content=content),
            'message': message,
            'branch': branch,
        }
        result = await self.session.put(ADD_FILE_URL.format(owner, repo_name, filename), data=json.dumps(data))
        if 200 <= result.status <= 299:
            return await result.json()
        if result.status == 401:
            raise NoAuthProvided
        if result.status == 409:
            raise FileAlreadyExists
        if result.status == 422:
            raise FileAlreadyExists('This file exists, and can only be edited.')
        return await result.json(), result.status