Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/splatnet3_scraper/auth/tokens/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,7 @@ def ready_for_endpoint(self, endpoint: str = "web") -> bool:
return now < self._web_service_token_expires_at

def ensure_tokens_valid(self) -> None:
now = time.time()
if now < self.next_available_at:
if time.time() < self.next_available_at:
raise AccountCooldownException(
"Account is cooling down for another"
f" {self.cooldown_remaining():.1f} seconds."
Expand All @@ -294,9 +293,6 @@ def ensure_tokens_valid(self) -> None:
if bullet_token.is_expired:
self.generate_bullet_token()

if now >= self._id_token_expires_at:
self.regenerate_tokens()

def record_response(self, status_code: int) -> None:
self.last_status_code = status_code
if status_code in (429, 503):
Expand Down
85 changes: 76 additions & 9 deletions src/splatnet3_scraper/query/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
)

T = TypeVar("T")
TOKEN_TIMESTAMP_OPTIONS = {
TOKENS.GTOKEN: "gtoken_timestamp",
TOKENS.BULLET_TOKEN: "bullet_token_timestamp",
}


class Config:
Expand Down Expand Up @@ -104,10 +108,7 @@ def regenerate_tokens(self) -> None:
TOKENS.GTOKEN,
TOKENS.BULLET_TOKEN,
]:
self.handler.set_value(
token,
self.token_manager.get_token(token).value,
)
self._sync_token_to_handler(token)
if self._output_file_path is not None:
try:
self.save_to_file()
Expand Down Expand Up @@ -251,6 +252,25 @@ def _configure_token_manager(self) -> None:
exc,
)

@staticmethod
def _timestamp_option_for_token(token_name: str) -> str | None:
return TOKEN_TIMESTAMP_OPTIONS.get(token_name)

def _sync_token_to_handler(self, token_name: str) -> None:
token = self.token_manager.get_token(token_name)
self.handler.set_value(token_name, token.value)
if timestamp_option := self._timestamp_option_for_token(token_name):
self.handler.set_value(timestamp_option, str(token.timestamp))

def _get_token_timestamp(self, token_name: str) -> float | None:
timestamp_option = self._timestamp_option_for_token(token_name)
if timestamp_option is None:
return None
try:
return cast(float | None, self.handler.get_value(timestamp_option))
except ValueError:
return None

def get_value(
self, option: str, default: T | None = None
) -> str | T | None:
Expand Down Expand Up @@ -282,10 +302,29 @@ def set_value(self, option: str, value: str | None) -> None:
TOKENS.GTOKEN,
TOKENS.BULLET_TOKEN,
]:
timestamp = self._get_token_timestamp(option)
if (token := self.handler.tokens[option]) is not None:
self.token_manager.add_token(
token,
option,
timestamp=timestamp,
)
if (
timestamp is None
and self._timestamp_option_for_token(option) is not None
):
self._sync_token_to_handler(option)
elif option in TOKEN_TIMESTAMP_OPTIONS.values():
token_name = next(
name
for name, timestamp_option in TOKEN_TIMESTAMP_OPTIONS.items()
if timestamp_option == option
)
if (token := self.handler.tokens[token_name]) is not None:
self.token_manager.add_token(
token,
token_name,
timestamp=self._get_token_timestamp(token_name),
)
elif option == "app_version_override":
self._apply_app_version_override()
Expand Down Expand Up @@ -354,9 +393,17 @@ def from_tokens(

prefix = prefix or Config.DEFAULT_PREFIX
handler = ConfigOptionHandler(prefix=prefix)
handler.set_value(TOKENS.SESSION_TOKEN, session_token)
handler.set_value(TOKENS.GTOKEN, gtoken)
handler.set_value(TOKENS.BULLET_TOKEN, bullet_token)
handler.set_value(
TOKENS.SESSION_TOKEN,
token_manager.get_token(TOKENS.SESSION_TOKEN).value,
)
for token_name in (TOKENS.GTOKEN, TOKENS.BULLET_TOKEN):
token = token_manager.get_token(token_name)
handler.set_value(token_name, token.value)
handler.set_value(
cast(str, TOKEN_TIMESTAMP_OPTIONS[token_name]),
str(token.timestamp),
)
handler.set_value("app_version_override", app_version)

return Config(
Expand Down Expand Up @@ -392,10 +439,22 @@ def from_config_handler(
gtoken = handler.get_value(TOKENS.GTOKEN)
except ValueError:
gtoken = None
try:
gtoken_timestamp = cast(
float | None, handler.get_value("gtoken_timestamp")
)
except ValueError:
gtoken_timestamp = None
try:
bullet_token = handler.get_value(TOKENS.BULLET_TOKEN)
except ValueError:
bullet_token = None
try:
bullet_token_timestamp = cast(
float | None, handler.get_value("bullet_token_timestamp")
)
except ValueError:
bullet_token_timestamp = None
try:
app_version = handler.get_value("app_version_override")
except ValueError:
Expand All @@ -408,9 +467,17 @@ def from_config_handler(
app_version=cast(str | None, app_version),
)
if gtoken is not None:
token_manager.add_token(gtoken, TOKENS.GTOKEN)
token_manager.add_token(
gtoken,
TOKENS.GTOKEN,
timestamp=gtoken_timestamp,
)
if bullet_token is not None:
token_manager.add_token(bullet_token, TOKENS.BULLET_TOKEN)
token_manager.add_token(
bullet_token,
TOKENS.BULLET_TOKEN,
timestamp=bullet_token_timestamp,
)
if gtoken is not None and bullet_token is not None:
token_manager.mark_tokens_fresh()

Expand Down
14 changes: 14 additions & 0 deletions src/splatnet3_scraper/query/config/config_option_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,20 @@ class ConfigOptionHandler:
deprecated_names=["bullettoken"],
env_var="BULLET_TOKEN",
),
ConfigOption[float](
name="gtoken_timestamp",
default=None,
section="tokens",
callback=float,
save_callback=lambda value: str(value),
),
ConfigOption[float](
name="bullet_token_timestamp",
default=None,
section="tokens",
callback=float,
save_callback=lambda value: str(value),
),
ConfigOption[str](
name="user_agent",
default=DEFAULT_USER_AGENT,
Expand Down
29 changes: 28 additions & 1 deletion tests/auth/tokens/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,4 +423,31 @@ def get_token(name: str, *, full_token: bool = True):
mock_token_manager.ensure_tokens_valid()

mock_token_manager.generate_gtoken.assert_called_once()
mock_token_manager.regenerate_tokens.assert_called_once()
mock_token_manager.regenerate_tokens.assert_not_called()

def test_ensure_tokens_valid_does_not_force_full_refresh_for_stale_app_timer(
self, mock_token_manager: TokenManager, monkeypatch: pytest.MonkeyPatch
) -> None:
tokens = {
TOKENS.GTOKEN: MagicMock(is_expired=False, value="gtoken"),
TOKENS.BULLET_TOKEN: MagicMock(is_expired=False, value="bullet"),
}

def get_token(name: str, *, full_token: bool = True):
return tokens[name]

mock_token_manager.keychain.get.side_effect = get_token
mock_token_manager.generate_gtoken = MagicMock()
mock_token_manager.generate_bullet_token = MagicMock()
mock_token_manager.regenerate_tokens = MagicMock()
mock_token_manager._id_token_expires_at = 400.0

monkeypatch.setattr(
f"{base_token_manager_path}.time.time", lambda: 500.0
)

mock_token_manager.ensure_tokens_valid()

mock_token_manager.generate_gtoken.assert_not_called()
mock_token_manager.generate_bullet_token.assert_not_called()
mock_token_manager.regenerate_tokens.assert_not_called()
Loading
Loading