diff --git a/Github/http.py b/Github/http.py index cf6c002..3d2296e 100644 --- a/Github/http.py +++ b/Github/http.py @@ -2,10 +2,15 @@ import aiohttp from types import SimpleNamespace +import re from .exceptions import * +from .objects import * +from .objects import APIOBJECT from .urls import * +LINK_PARSING_RE = re.compile(r"<(\S+(\S))>; rel=\"(\S+)\"") + # aiohttp request tracking / checking bits async def on_req_start( session: aiohttp.ClientSession, @@ -39,6 +44,51 @@ async def make_session(*, headers: dict[str, str], authorization: aiohttp.BasicA ) return session +# pagination +class Paginator: + """This class handles pagination for objects like Repos and Orgs.""" + def __init__(self, session: aiohttp.ClientSession, response: aiohttp.ClientResponse, target_type: str): + self.session = session + self.response = response + self.should_paginate = bool(self.response.headers.get('Link', False)) + types: dict[str, User | ...] = { + 'user': User, + } + self.target_type = types[target_type] + self.pages = {} + self.is_exhausted = False + self.current_page = 1 + self.next_page = self.current_page + 1 + self.parse_header(response) + + async def fetch_page(self, link) -> dict[str, str | int]: + """Fetches a specific page and returns the JSON.""" + return await (await self.session.get(link)).json() + + async def early_return(self) -> list[APIOBJECT]: + # I don't rightly remember what this does differently, may have a good ol redesign later + return [self.target_type(data, self.session) for data in await self.response.json()] + + async def exhaust(self) -> list[APIOBJECT]: + """Iterates through all of the pages for the relevant object and creates them.""" + if self.should_paginate: + return await self.early_return() + out = [] + for page in range(1, self.max_page+1): + result = await self.session.get(self.bare_link + str(page)) + out.extend([self.target_type(item, self.session) for item in await result.json()]) + self.is_exhausted = True + return out + + def parse_header(self, response: aiohttp.ClientResponse) -> None: + """Predicts wether a call will exceed the ratelimit ahead of the call.""" + header = response.headers.get('Link') + groups = LINK_PARSING_RE.findall(header) + self.max_page = int(groups[1][1]) + if int(response.headers['X-RateLimit-Remaining']) < self.max_page: + raise WillExceedRatelimit(response, self.max_page) + self.bare_link = groups[0][0][:-1] + # user-related functions / utils GitHubUserData = dict[str, str | int]