mirror of
https://github.com/RGBCube/GitHubWrapper
synced 2025-05-17 06:25:10 +00:00
326 lines
13 KiB
Python
326 lines
13 KiB
Python
# == http.py ==#
|
|
|
|
from __future__ import annotations
|
|
from asyncio.base_subprocess import ReadSubprocessPipeProto
|
|
from base64 import b64encode
|
|
|
|
import json
|
|
import re
|
|
from datetime import datetime
|
|
from types import SimpleNamespace
|
|
from typing import Any, Dict, Literal, NamedTuple, Optional, Type, Tuple, Union, List
|
|
from typing_extensions import TypeAlias, reveal_type
|
|
import platform
|
|
|
|
import aiohttp
|
|
|
|
from .exceptions import *
|
|
from .exceptions import GistNotFound, RepositoryAlreadyExists, MissingPermissions
|
|
from .exceptions import FileAlreadyExists
|
|
from .exceptions import ResourceAlreadyExists
|
|
from .objects import User, Gist, Repository, File, bytes_to_b64
|
|
from .urls import *
|
|
from . import __version__
|
|
|
|
# Public names exported by ``from .http import *``.
__all__: Tuple[str, ...] = ('Paginator', 'http')
|
|
|
|
|
|
# Parses one entry of GitHub's ``Link`` pagination header, e.g.
#   <https://api.github.com/user/repos?page=515>; rel="last"
# Groups: (1) the full URL, (2) the complete run of digits the URL ends with
# (the page number for GitHub's pagination links; '' if the URL does not end in
# a digit), (3) the rel value such as "next" or "last".
# Group 2 used to capture only the final character, so page=515 read as 5.
LINK_PARSING_RE = re.compile(r'<(\S+?(\d*))>; rel="([^"]+)"')
|
|
|
|
|
|
class Rates(NamedTuple):
    """Snapshot of GitHub's rate-limit state, stored on a session as ``_rates``.

    The three counters hold the raw header strings. The two time fields hold
    datetimes once a response has been recorded, or '' as a placeholder before that.
    """

    # Raw value of the X-RateLimit-Remaining header.
    remaining: str
    # Raw value of the X-RateLimit-Used header.
    used: str
    # Raw value of the X-RateLimit-Limit header.
    total: str
    # When the rate-limit window resets ('' until known).
    reset_when: Union[datetime, str]
    # When the most recent response was recorded ('' until known).
    last_request: Union[datetime, str]
|
|
|
|
|
|
# aiohttp request tracking / checking bits
|
|
async def on_req_start(
    session: aiohttp.ClientSession, ctx: SimpleNamespace, params: aiohttp.TraceRequestStartParams
) -> None:
    """Before-request trace hook (placeholder).

    Currently a no-op: it does not yet inspect ``session._rates`` or hold back a
    request that would overrun the ratelimit. The parameters follow aiohttp's
    trace-signal callback signature and are unused for now.
    """
    return None
|
|
|
|
|
|
async def on_req_end(session: aiohttp.ClientSession, ctx: SimpleNamespace, params: aiohttp.TraceRequestEndParams) -> None:
    """After-request trace hook: records GitHub's rate-limit headers on ``session._rates``.

    A response without a complete, numeric set of ``X-RateLimit-*`` headers leaves the
    previous snapshot untouched instead of raising KeyError/ValueError from the hook.
    Both timestamps are naive UTC datetimes, so they can be compared with each other.
    """
    headers = params.response.headers
    try:
        remaining = headers['X-RateLimit-Remaining']
        used = headers['X-RateLimit-Used']
        total = headers['X-RateLimit-Limit']
        reset_epoch = int(headers['X-RateLimit-Reset'])
    except (KeyError, ValueError):
        return

    # Convert the epoch in UTC so it shares a time base with utcnow() below
    # (fromtimestamp() would give local time and skew any comparison).
    reset_when = datetime.utcfromtimestamp(reset_epoch)
    last_req = datetime.utcnow()

    session._rates = Rates(remaining, used, total, reset_when, last_req)
|
|
|
|
|
|
# Shared tracing configuration passed to the ClientSessions created in this module:
# the end hook snapshots rate-limit headers, the start hook runs before each request.
trace_config = aiohttp.TraceConfig()
trace_config.on_request_end.append(on_req_end)
trace_config.on_request_start.append(on_req_start)
|
|
|
|
# The object classes a Paginator can wrap JSON items in.
APIType: TypeAlias = Union[User, Gist, Repository]
|
|
|
|
|
|
async def make_session(*, headers: Dict[str, str], authorization: Union[aiohttp.BasicAuth, None]) -> aiohttp.ClientSession:
    """Creates a ClientSession wired to the shared trace config, with a User-Agent header.

    ``headers`` is copied, never modified. A default User-Agent (same format as the
    one ``http`` sends) is added when no non-empty ``'User-Agent'`` entry is given.
    The returned session carries a placeholder ``_rates`` snapshot.
    """
    # Copy so the default User-Agent doesn't leak back into the caller's dict.
    headers = dict(headers)
    if not headers.get('User-Agent'):
        headers['User-Agent'] = (
            f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} '
            f'Python/{platform.python_version()} aiohttp/{aiohttp.__version__}'
        )

    session = aiohttp.ClientSession(auth=authorization, headers=headers, trace_configs=[trace_config])
    # Empty placeholder until a trace hook records values from a real response.
    session._rates = Rates('', '', '', '', '')
    return session
|
|
|
|
|
|
# pagination
|
|
class Paginator:
    """Handles GitHub's Link-header pagination for listings of users, gists or repos.

    ``target_type`` selects the class each JSON item is wrapped in: ``'user'``,
    ``'gist'`` or ``'repo'``.
    """

    # Locates the ``page=N`` query parameter inside a pagination URL.
    _PAGE_PARAM_RE = re.compile(r'([?&]page=)(\d+)')

    def __init__(self, session: aiohttp.ClientSession, response: aiohttp.ClientResponse, target_type: str):
        self.session = session
        self.response = response
        # No Link header means the whole listing is already in ``response``.
        self.should_paginate = bool(self.response.headers.get('Link', False))
        types: Dict[str, Type[APIType]] = {  # note: the type checker doesnt see subclasses like that
            'user': User,
            'gist': Gist,
            'repo': Repository,
        }
        self.target_type: Type[APIType] = types[target_type]
        self.pages = {}
        self.is_exhausted = False
        self.current_page = 1
        self.next_page = self.current_page + 1
        # Single-page defaults; parse_header overwrites them when there is a Link header.
        self.max_page = 1
        self.bare_link = ''
        self._last_link = ''
        if self.should_paginate:
            self.parse_header(response)

    async def fetch_page(self, link: str) -> Dict[str, Union[str, int]]:
        """Fetches a specific page and returns the JSON."""
        async with self.session.get(link) as result:
            return await result.json()

    async def early_return(self) -> List[APIType]:
        """Builds objects from the initial response only (the single-page case)."""
        # NOTE(review): passes the Paginator as the objects' second argument; confirm
        # that is what User/Gist/Repository expect there.
        return [self.target_type(data, self) for data in await self.response.json()]  # type: ignore

    async def exhaust(self) -> List[APIType]:
        """Fetches every page of the listing and builds an object for each item.

        A listing without pagination is answered from the initial response alone.
        """
        out: List[APIType] = []
        if not self.should_paginate:
            out = await self.early_return()
        else:
            for page in range(1, self.max_page + 1):
                async with self.session.get(self._page_link(page)) as result:
                    items = await result.json()
                out.extend(self.target_type(item, self) for item in items)  # type: ignore

        self.is_exhausted = True
        return out

    def parse_header(self, response: aiohttp.ClientResponse) -> None:
        """Reads the page count from the ``Link`` header and checks it against the ratelimit.

        ``max_page`` comes from the ``rel="last"`` link. If that link is missing or has no
        ``page`` parameter, the listing is treated as a single page.

        Raises:
            WillExceedRatelimit: fetching every page would take more requests than
                ``X-RateLimit-Remaining`` allows.
        """
        links = {rel: url for url, _, rel in LINK_PARSING_RE.findall(response.headers['Link'])}
        last_link = links.get('last', '')
        page_match = self._PAGE_PARAM_RE.search(last_link)
        if page_match is None:
            self.should_paginate = False
            return

        # Use the whole number, not just its last character (page=515 -> 515).
        self.max_page = int(page_match.group(2))
        remaining = response.headers.get('X-RateLimit-Remaining')
        if remaining is not None and int(remaining) < self.max_page:
            raise WillExceedRatelimit(response, self.max_page)

        self._last_link = last_link
        # Legacy attribute: the URL prefix up to and including ``page=``.
        self.bare_link = last_link[: page_match.end(1)]

    def _page_link(self, page: int) -> str:
        """Returns the URL of ``page``: the rel="last" link with its page number replaced.

        Other query parameters (e.g. ``per_page``) are kept wherever they appear.
        """
        return self._PAGE_PARAM_RE.sub(lambda m: f'{m.group(1)}{page}', self._last_link, count=1)
|
|
|
|
|
|
# GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]]
|
|
# Commenting this out for now; consider using TypedDicts instead in the future <3
|
|
|
|
|
|
class http:
    """Async client for the GitHub REST endpoints used by this wrapper.

    ``await http(headers, auth)`` runs :meth:`start`, which opens the underlying
    ``aiohttp.ClientSession``. Every response is handled inside ``async with`` so its
    connection is released even when an exception is raised.
    """

    def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]) -> None:
        # Copy so the default User-Agent is not written back into the caller's dict.
        headers = dict(headers)
        if not headers.get('User-Agent'):
            headers['User-Agent'] = (
                f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} '
                f'Python/{platform.python_version()} aiohttp/{aiohttp.__version__}'
            )

        self._rates = Rates('', '', '', '', '')
        self.headers = headers
        self.auth = auth

    def __await__(self):
        # Lets ``client = await http(...)`` construct and open the session in one step.
        return self.start().__await__()

    async def start(self):
        """Opens the ClientSession with the shared trace config and returns ``self``."""
        self.session = aiohttp.ClientSession(
            headers=self.headers,  # type: ignore
            auth=self.auth,
            trace_configs=[trace_config],
        )
        if not hasattr(self.session, '_rates'):
            # Placeholder until a trace hook records values from a real response.
            self.session._rates = Rates('', '', '', '', '')
        return self

    def update_headers(self, *, flush: bool = False, new_headers: Dict[str, Union[str, int]]):
        """Changes the session's default headers.

        With ``flush=True`` the defaults become exactly ``new_headers``. Otherwise
        ``new_headers`` are merged over the current defaults case-insensitively, so
        ``'user-agent'`` replaces an existing ``'User-Agent'`` instead of duplicating it.
        """
        from multidict import CIMultiDict

        if flush:
            merged = CIMultiDict(new_headers)
        else:
            merged = CIMultiDict(self.session.headers)
            merged.update(new_headers)
        # Private aiohttp attribute backing ClientSession.headers.
        self.session._default_headers = merged  # type: ignore

    async def update_auth(self, *, username: str, token: str):
        """Switches to ``BasicAuth(username, token)`` by replacing the session.

        Default headers, trace configs and the last rate-limit snapshot carry over to
        the new session, and ``self.auth`` is updated so :meth:`data` stays accurate.
        """
        auth = aiohttp.BasicAuth(username, token)
        old_session = self.session
        headers = old_session.headers
        config = old_session.trace_configs
        rates = getattr(old_session, '_rates', Rates('', '', '', '', ''))
        await old_session.close()

        self.session = aiohttp.ClientSession(headers=headers, auth=auth, trace_configs=config)
        self.session._rates = rates
        self.auth = auth

    def data(self):
        """Returns the session's current default headers (as a plain dict) and the auth in use."""
        return {'headers': {**self.session.headers}, 'auth': self.auth}

    async def latency(self):
        """Returns the seconds until the response headers of a GET to BASE_URL arrive."""
        import time  # perf_counter is monotonic, unlike wall-clock datetimes

        start = time.perf_counter()
        async with self.session.get(BASE_URL):
            elapsed = time.perf_counter() - start
        return elapsed

    async def _get_json(self, url: str, error: Type[Exception]) -> Any:
        """GETs ``url`` and returns its decoded JSON body; raises ``error`` on a non-2xx status."""
        async with self.session.get(url) as result:
            if 200 <= result.status <= 299:
                return await result.json()
        raise error

    async def _get_json_list(self, url: str) -> List[Dict[str, Union[str, int]]]:
        """GETs a JSON array; on a non-2xx status logs a warning and returns ``[]``.

        Deliberately best-effort: callers get an empty list rather than an exception.
        """
        async with self.session.get(url) as result:
            if 200 <= result.status <= 299:
                return await result.json()
            status = result.status

        import logging

        logging.getLogger(__name__).warning('GET %s returned HTTP %s; returning no items', url, status)
        return []

    @staticmethod
    def _repo_payload(
        name: str,
        description: str,
        public: bool,
        gitignore_template: Optional[str],
        license_template: Optional[str],
    ) -> Dict[str, Any]:
        """Builds the JSON body for repository creation.

        GitHub's API takes ``private`` (it has no ``public`` field) and
        ``license_template``; unset templates are omitted rather than sent as null.
        """
        payload: Dict[str, Any] = {'name': name, 'description': description, 'private': not public}
        if gitignore_template is not None:
            payload['gitignore_template'] = gitignore_template
        if license_template is not None:
            payload['license_template'] = license_template
        return payload

    async def get_self(self) -> Dict[str, Union[str, int]]:
        """Returns the authenticated User's data; raises InvalidToken on a non-2xx status."""
        return await self._get_json(SELF_URL, InvalidToken)

    async def get_user(self, username: str) -> Dict[str, Union[str, int]]:
        """Returns a user's public data in JSON format; raises UserNotFound on a non-2xx status."""
        return await self._get_json(USERS_URL.format(username), UserNotFound)

    async def get_user_repos(self, _user: User) -> List[Dict[str, Union[str, int]]]:
        """Returns the JSON list of ``_user``'s repositories ([] on an error status)."""
        return await self._get_json_list(USER_REPOS_URL.format(_user.login))

    async def get_user_gists(self, _user: User) -> List[Dict[str, Union[str, int]]]:
        """Returns the JSON list of ``_user``'s gists ([] on an error status)."""
        return await self._get_json_list(USER_GISTS_URL.format(_user.login))

    async def get_user_orgs(self, _user: User) -> List[Dict[str, Union[str, int]]]:
        """Returns the JSON list of ``_user``'s organizations ([] on an error status)."""
        return await self._get_json_list(USER_ORGS_URL.format(_user.login))

    async def get_repo(self, owner: str, repo_name: str) -> Optional[Dict[str, Union[str, int]]]:
        """Returns a Repo's raw JSON from the given owner and repo name; raises RepositoryNotFound."""
        return await self._get_json(REPO_URL.format(owner, repo_name), RepositoryNotFound)

    async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> Optional[Dict[str, Any]]:
        """Returns a single issue's JSON from the given owner and repo name; raises IssueNotFound."""
        return await self._get_json(REPO_ISSUE_URL.format(owner, repo_name, issue_number), IssueNotFound)

    async def delete_repo(self, owner: Optional[str], repo_name: str) -> Optional[str]:
        """Deletes a Repo from the given owner and repo name.

        Raises MissingPermissions on 403 and RepositoryNotFound on any other failure.
        """
        async with self.session.delete(REPO_URL.format(owner, repo_name)) as result:
            status = result.status
        if 204 <= status <= 299:
            return 'Successfully deleted repository.'
        if status == 403:
            raise MissingPermissions
        raise RepositoryNotFound

    async def delete_gist(self, gist_id: Union[str, int]) -> Optional[str]:
        """Deletes a Gist from the given gist id.

        Raises MissingPermissions on 403 and GistNotFound on any other non-204 status.
        """
        async with self.session.delete(GIST_URL.format(gist_id)) as result:
            status = result.status
        if status == 204:
            return 'Successfully deleted gist.'
        if status == 403:
            raise MissingPermissions
        raise GistNotFound

    async def get_org(self, org_name: str) -> Dict[str, Union[str, int]]:
        """Returns an org's public data in JSON format; raises OrganizationNotFound."""
        return await self._get_json(ORG_URL.format(org_name), OrganizationNotFound)

    async def get_gist(self, gist_id: str) -> Dict[str, Union[str, int]]:
        """Returns a gist's raw JSON from the given gist id; raises GistNotFound."""
        return await self._get_json(GIST_URL.format(gist_id), GistNotFound)

    async def create_gist(
        self, *, files: Optional[List['File']] = None, description: str = 'Default description', public: bool = False
    ) -> Dict[str, Union[str, int]]:
        """Creates a gist from ``files``; raises InvalidToken unless GitHub answers 201."""
        data = {
            'description': description,
            'public': public,
            # Keyed by filename; each entry carries the file's current content.
            'files': {file.filename: {'filename': file.filename, 'content': file.read()} for file in files or ()},
        }
        async with self.session.post(
            CREATE_GIST_URL,
            data=json.dumps(data),
            # aiohttp merges per-request headers over the session's default headers.
            headers={'Accept': 'application/vnd.github.v3+json'},
        ) as result:
            if result.status == 201:
                return await result.json()
        raise InvalidToken

    async def create_repo(
        self, name: str, description: str, public: bool, gitignore: Optional[str], license: Optional[str]
    ) -> Dict[str, Union[str, int]]:
        """Creates a repo for you with given data.

        Raises NoAuthProvided on 401 and RepositoryAlreadyExists on any other failure.
        """
        payload = self._repo_payload(name, description, public, gitignore, license)
        async with self.session.post(CREATE_REPO_URL, data=json.dumps(payload)) as result:
            if 200 <= result.status <= 299:
                return await result.json()
            status = result.status
        if status == 401:
            raise NoAuthProvided
        raise RepositoryAlreadyExists

    async def add_file(self, owner: str, repo_name: str, filename: str, content: str, message: str, branch: str):
        """Adds a file to the given repo.

        Returns the JSON body on 2xx. Raises NoAuthProvided on 401 and FileAlreadyExists
        on 409/422; any other status returns ``(json, status)``.
        """
        data = {
            'content': bytes_to_b64(content=content),
            'message': message,
            'branch': branch,
        }
        async with self.session.put(ADD_FILE_URL.format(owner, repo_name, filename), data=json.dumps(data)) as result:
            if 200 <= result.status <= 299:
                return await result.json()
            if result.status == 401:
                raise NoAuthProvided
            if result.status == 409:
                raise FileAlreadyExists
            if result.status == 422:
                raise FileAlreadyExists('This file exists, and can only be edited.')
            return await result.json(), result.status
|