1
Fork 0
mirror of https://github.com/RGBCube/GitHubWrapper synced 2025-05-22 08:55:08 +00:00

refactor w/ black

This commit is contained in:
NextChai 2022-04-30 02:27:16 -04:00
parent cc3cde89c8
commit b474492637
8 changed files with 187 additions and 165 deletions

View file

@ -1,4 +1,4 @@
#== http.py ==#
# == http.py ==#
from __future__ import annotations
@ -29,60 +29,57 @@ Rates = NamedTuple('Rates', 'remaining', 'used', 'total', 'reset_when', 'last_re
# aiohttp request tracking / checking bits
async def on_req_start(
session: aiohttp.ClientSession,
ctx: SimpleNamespace,
params: aiohttp.TraceRequestStartParams
session: aiohttp.ClientSession, ctx: SimpleNamespace, params: aiohttp.TraceRequestStartParams
) -> None:
"""Before-request hook to make sure we don't overrun the ratelimit."""
#print(repr(session), repr(ctx), repr(params))
# print(repr(session), repr(ctx), repr(params))
pass
async def on_req_end(
session: aiohttp.ClientSession,
ctx: SimpleNamespace,
params: aiohttp.TraceRequestEndParams
) -> None:
async def on_req_end(session: aiohttp.ClientSession, ctx: SimpleNamespace, params: aiohttp.TraceRequestEndParams) -> None:
"""After-request hook to adjust remaining requests on this time frame."""
headers = params.response.headers
remaining = headers['X-RateLimit-Remaining']
used = headers['X-RateLimit-Used']
total = headers['X-RateLimit-Limit']
reset_when = datetime.fromtimestamp(int(headers['X-RateLimit-Reset']))
last_req = datetime.utcnow()
remaining = headers['X-RateLimit-Remaining']
used = headers['X-RateLimit-Used']
total = headers['X-RateLimit-Limit']
reset_when = datetime.fromtimestamp(int(headers['X-RateLimit-Reset']))
last_req = datetime.utcnow()
session._rates = Rates(remaining, used, total, reset_when, last_req)
trace_config = aiohttp.TraceConfig()
trace_config.on_request_start.append(on_req_start)
trace_config.on_request_end.append(on_req_end)
APIType: TypeAlias = Union[User, Gist, Repository]
async def make_session(*, headers: Dict[str, str], authorization: Union[aiohttp.BasicAuth, None]) -> aiohttp.ClientSession:
"""This makes the ClientSession, attaching the trace config and ensuring a UA header is present."""
if not headers.get('User-Agent'):
headers['User-Agent'] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python {platform.python_version()} aiohttp {aiohttp.__version__}'
headers[
'User-Agent'
] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python {platform.python_version()} aiohttp {aiohttp.__version__}'
session = aiohttp.ClientSession(
auth=authorization,
headers=headers,
trace_configs=[trace_config]
)
session._rates = Rates('', '' , '', '', '')
session = aiohttp.ClientSession(auth=authorization, headers=headers, trace_configs=[trace_config])
session._rates = Rates('', '', '', '', '')
return session
# pagination
class Paginator:
"""This class handles pagination for objects like Repos and Orgs."""
def __init__(self, session: aiohttp.ClientSession, response: aiohttp.ClientResponse, target_type: str):
self.session = session
self.response = response
self.should_paginate = bool(self.response.headers.get('Link', False))
types: Dict[str, Type[APIType]] = { # note: the type checker doesnt see subclasses like that
types: Dict[str, Type[APIType]] = { # note: the type checker doesnt see subclasses like that
'user': User,
'gist' : Gist,
'repo' : Repository
'gist': Gist,
'repo': Repository,
}
self.target_type: Type[APIType] = types[target_type]
self.pages = {}
@ -97,18 +94,18 @@ class Paginator:
async def early_return(self) -> List[APIType]:
# I don't rightly remember what this does differently, may have a good ol redesign later
return [self.target_type(data, self) for data in await self.response.json()] # type: ignore
return [self.target_type(data, self) for data in await self.response.json()] # type: ignore
async def exhaust(self) -> List[APIType]:
"""Iterates through all of the pages for the relevant object and creates them."""
if self.should_paginate:
return await self.early_return()
out: List[APIType] = []
for page in range(1, self.max_page+1):
for page in range(1, self.max_page + 1):
result = await self.session.get(self.bare_link + str(page))
out.extend([self.target_type(item, self) for item in await result.json()]) # type: ignore
out.extend([self.target_type(item, self) for item in await result.json()]) # type: ignore
self.is_exhausted = True
return out
@ -121,14 +118,18 @@ class Paginator:
raise WillExceedRatelimit(response, self.max_page)
self.bare_link = groups[0][0][:-1]
# GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]]
# Commentnig this out for now, consider using TypeDict's instead in the future <3
class http:
def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]) -> None:
if not headers.get('User-Agent'):
headers['User-Agent'] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python/{platform.python_version()} aiohttp/{aiohttp.__version__}'
headers[
'User-Agent'
] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python/{platform.python_version()} aiohttp/{aiohttp.__version__}'
self._rates = Rates('', '', '', '', '')
self.headers = headers
self.auth = auth
@ -138,7 +139,7 @@ class http:
async def start(self):
self.session = aiohttp.ClientSession(
headers=self.headers, # type: ignore
headers=self.headers, # type: ignore
auth=self.auth,
trace_configs=[trace_config],
)
@ -149,23 +150,20 @@ class http:
def update_headers(self, *, flush: bool = False, new_headers: Dict[str, Union[str, int]]):
if flush:
from multidict import CIMultiDict
self.session._default_headers = CIMultiDict(**new_headers) # type: ignore
self.session._default_headers = CIMultiDict(**new_headers) # type: ignore
else:
self.session._default_headers = {**self.session.headers, **new_headers} # type: ignore
self.session._default_headers = {**self.session.headers, **new_headers} # type: ignore
async def update_auth(self, *, username: str, token: str):
auth = aiohttp.BasicAuth(username, token)
headers = self.session.headers
config = self.session.trace_configs
await self.session.close()
self.session = aiohttp.ClientSession(
headers=headers,
auth=auth,
trace_configs=config
)
self.session = aiohttp.ClientSession(headers=headers, auth=auth, trace_configs=config)
def data(self):
#return session headers and auth
# return session headers and auth
headers = {**self.session.headers}
return {'headers': headers, 'auth': self.auth}
@ -175,52 +173,52 @@ class http:
await self.session.get(BASE_URL)
return (datetime.utcnow() - start).total_seconds()
async def get_self(self) -> Dict[str, Union [str, int]]:
async def get_self(self) -> Dict[str, Union[str, int]]:
"""Returns the authenticated User's data"""
result = await self.session.get(SELF_URL)
if 200 <= result.status <= 299:
return await result.json()
raise InvalidToken
async def get_user(self, username: str) -> Dict[str, Union [str, int]]:
async def get_user(self, username: str) -> Dict[str, Union[str, int]]:
"""Returns a user's public data in JSON format."""
result = await self.session.get(USERS_URL.format(username))
if 200 <= result.status <= 299:
return await result.json()
raise UserNotFound
async def get_user_repos(self, _user: User) -> List[Dict[str, Union [str, int]]]:
async def get_user_repos(self, _user: User) -> List[Dict[str, Union[str, int]]]:
result = await self.session.get(USER_REPOS_URL.format(_user.login))
if 200 <= result.status <= 299:
return await result.json()
print('This shouldn\'t be reachable')
return []
async def get_user_gists(self, _user: User) -> List[Dict[str, Union [str, int]]]:
async def get_user_gists(self, _user: User) -> List[Dict[str, Union[str, int]]]:
result = await self.session.get(USER_GISTS_URL.format(_user.login))
if 200 <= result.status <= 299:
return await result.json()
print('This shouldn\'t be reachable')
return []
async def get_user_orgs(self, _user: User) -> List[Dict[str, Union [str, int]]]:
async def get_user_orgs(self, _user: User) -> List[Dict[str, Union[str, int]]]:
result = await self.session.get(USER_ORGS_URL.format(_user.login))
if 200 <= result.status <= 299:
return await result.json()
print('This shouldn\'t be reachable')
return []
async def get_repo(self, owner: str, repo_name: str) -> Dict[str, Union [str, int]]:
async def get_repo(self, owner: str, repo_name: str) -> Dict[str, Union[str, int]]:
"""Returns a Repo's raw JSON from the given owner and repo name."""
result = await self.session.get(REPO_URL.format(owner, repo_name))
if 200 <= result.status <= 299:
return await result.json()
raise RepositoryNotFound
async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> Dict[str, Union [str, int]]:
async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> Dict[str, Union[str, int]]:
"""Returns a single issue's JSON from the given owner and repo name."""
result = await self.session.get(REPO_ISSUE_URL.format(owner, repo_name, issue_number))
if 200 <= result.status <= 299:
@ -245,14 +243,14 @@ class http:
raise MissingPermissions
raise GistNotFound
async def get_org(self, org_name: str) -> Dict[str, Union [str, int]]:
async def get_org(self, org_name: str) -> Dict[str, Union[str, int]]:
"""Returns an org's public data in JSON format."""
result = await self.session.get(ORG_URL.format(org_name))
if 200 <= result.status <= 299:
return await result.json()
raise OrganizationNotFound
async def get_gist(self, gist_id: int) -> Dict[str, Union [str, int]]:
async def get_gist(self, gist_id: int) -> Dict[str, Union[str, int]]:
"""Returns a gist's raw JSON from the given gist id."""
result = await self.session.get(GIST_URL.format(gist_id))
if 200 <= result.status <= 299:
@ -260,29 +258,26 @@ class http:
raise GistNotFound
async def create_gist(
self,
*,
files: List['File'] = [],
description: str = 'Default description',
public: bool = False
) -> Dict[str, Union [str, int]]:
self, *, files: List['File'] = [], description: str = 'Default description', public: bool = False
) -> Dict[str, Union[str, int]]:
data = {}
data['description'] = description
data['public'] = public
data['files'] = {}
for file in files:
data['files'][file.filename] = {
'filename' : file.filename, # helps editing the file
'content': file.read()
}
data['files'][file.filename] = {'filename': file.filename, 'content': file.read()} # helps editing the file
data = json.dumps(data)
_headers = dict(self.session.headers)
result = await self.session.post(CREATE_GIST_URL, data=data, headers=_headers|{'Accept': 'application/vnd.github.v3+json'})
result = await self.session.post(
CREATE_GIST_URL, data=data, headers=_headers | {'Accept': 'application/vnd.github.v3+json'}
)
if 201 == result.status:
return await result.json()
raise InvalidToken
async def create_repo(self, name: str, description: str, public: bool, gitignore: Optional[str], license: Optional[str]) -> Dict[str, Union [str, int]]:
async def create_repo(
self, name: str, description: str, public: bool, gitignore: Optional[str], license: Optional[str]
) -> Dict[str, Union[str, int]]:
"""Creates a repo for you with given data"""
data = {
'name': name,