From 59029f049df6d709cb4645d2e087e8772a49b00b Mon Sep 17 00:00:00 2001 From: Paillat Date: Sun, 28 Jul 2024 19:26:13 +0200 Subject: [PATCH] Fix: sqlalchemy errors break session --- src/db_adapters/__init__.py | 4 ++ src/db_adapters/lists.py | 79 +++++++++++++--------- src/db_adapters/media.py | 43 +++++++----- src/db_adapters/misc.py | 9 +++ src/db_adapters/user.py | 51 +++++++------- src/exts/tvdb_info/ui/episode_view.py | 2 +- src/exts/tvdb_info/ui/movie_series_view.py | 8 +-- src/exts/tvdb_info/ui/profile_view.py | 38 ++++++----- src/exts/tvdb_info/ui/search_view.py | 4 +- 9 files changed, 139 insertions(+), 99 deletions(-) create mode 100644 src/db_adapters/misc.py diff --git a/src/db_adapters/__init__.py b/src/db_adapters/__init__.py index e220e57..e6881be 100644 --- a/src/db_adapters/__init__.py +++ b/src/db_adapters/__init__.py @@ -7,6 +7,8 @@ list_remove_item_safe, refresh_list_items, ) +from .media import series_get +from .misc import refresh from .user import user_create_list, user_get, user_get_list_safe, user_get_safe __all__ = [ @@ -21,4 +23,6 @@ "refresh_list_items", "get_list_item", "list_remove_item_safe", + "refresh", + "series_get", ] diff --git a/src/db_adapters/lists.py b/src/db_adapters/lists.py index 2230c98..dadc740 100644 --- a/src/db_adapters/lists.py +++ b/src/db_adapters/lists.py @@ -32,39 +32,47 @@ async def list_put_item( :raises ValueError: If the item is already present in the list. """ - if series_id: - await ensure_media(session, tvdb_id, kind, series_id=series_id) - else: - await ensure_media(session, tvdb_id, kind) - if await session.get(UserListItem, (user_list.id, tvdb_id, kind)) is not None: - raise ValueError(f"Item {tvdb_id} is already in list {user_list.id}.") - - item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind) - session.add(item) - await session.commit() - return item + async with session: + if series_id: + await ensure_media(session, tvdb_id, kind, series_id=series_id) + else: + await ensure_media(session, tvdb_id, kind) + if await session.get(UserListItem, (user_list.id, tvdb_id, kind)) is not None: + raise ValueError(f"Item {tvdb_id} is already in list {user_list.id}.") + + item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind) + session.add(item) + await session.commit() + return item async def list_get_item( session: AsyncSession, user_list: UserList, tvdb_id: int, kind: UserListItemKind ) -> UserListItem | None: """Get an item from a user list.""" - return await session.get(UserListItem, (user_list.id, tvdb_id, kind)) + async with session: + return await session.get(UserListItem, (user_list.id, tvdb_id, kind)) -async def list_remove_item(session: AsyncSession, user_list: UserList, item: UserListItem) -> None: +async def list_remove_item(session: AsyncSession, user_list: UserList, item: UserListItem) -> UserList: """Remove an item from a user list.""" - await session.delete(item) - await session.commit() - await session.refresh(user_list, ["items"]) + async with session: + item = await session.merge(item) + user_list = await session.merge(user_list) + await session.delete(item) + await session.commit() + await session.refresh(user_list, ["items"]) + return user_list async def list_remove_item_safe( session: AsyncSession, user_list: UserList, tvdb_id: int, kind: UserListItemKind -) -> None: +) -> UserList: """Removes an item from a user list if it exists.""" - if item := await list_get_item(session, user_list, tvdb_id, kind): - await list_remove_item(session, user_list, item) + async with session: + if item := await list_get_item(session, user_list, tvdb_id, kind): + return await list_remove_item(session, user_list, item) + return user_list @overload @@ -90,23 +98,27 @@ async def list_put_item_safe( session: AsyncSession, user_list: UserList, tvdb_id: int, kind: UserListItemKind, series_id: int | None = None ) -> UserListItem: """Add an item to a user list, or return the existing item if it is already present.""" - if series_id: - await ensure_media(session, tvdb_id, kind, series_id=series_id) - else: - await ensure_media(session, tvdb_id, kind) - item = await list_get_item(session, user_list, tvdb_id, kind) - if item: + async with session: + if series_id: + await ensure_media(session, tvdb_id, kind, series_id=series_id) + else: + await ensure_media(session, tvdb_id, kind) + item = await list_get_item(session, user_list, tvdb_id, kind) + if item: + return item + + item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind) + session.add(item) + await session.commit() return item - item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind) - session.add(item) - await session.commit() - return item - -async def refresh_list_items(session: AsyncSession, user_list: UserList) -> None: +async def refresh_list_items(session: AsyncSession, user_list: UserList) -> UserList: """Refresh the items in a user list.""" - await session.refresh(user_list, ["items"]) + async with session: + user_list = await session.merge(user_list) + await session.refresh(user_list, ["items"]) + return user_list async def get_list_item( @@ -116,4 +128,5 @@ async def get_list_item( kind: UserListItemKind, ) -> UserListItem | None: """Get a user list.""" - return await session.get(UserListItem, (user_list.id, tvdb_id, kind)) + async with session: + return await session.get(UserListItem, (user_list.id, tvdb_id, kind)) diff --git a/src/db_adapters/media.py b/src/db_adapters/media.py index 5366640..11fb2c2 100644 --- a/src/db_adapters/media.py +++ b/src/db_adapters/media.py @@ -8,22 +8,29 @@ async def ensure_media(session: AsyncSession, tvdb_id: int, kind: UserListItemKind, **kwargs: Any) -> None: """Ensure that a tvdb media item is present in its respective table.""" - match kind: - case UserListItemKind.MOVIE: - cls = Movie - case UserListItemKind.SERIES: - cls = Series - case UserListItemKind.EPISODE: - cls = Episode - media = await session.get(cls, tvdb_id) - if media is None: - media = cls(tvdb_id=tvdb_id, **kwargs) - session.add(media) - await session.commit() - - if isinstance(media, Episode): - await session.refresh(media, ["series"]) - if not media.series: - series = Series(tvdb_id=kwargs["series_id"]) - session.add(series) + async with session: + match kind: + case UserListItemKind.MOVIE: + cls = Movie + case UserListItemKind.SERIES: + cls = Series + case UserListItemKind.EPISODE: + cls = Episode + media = await session.get(cls, tvdb_id) + if media is None: + media = cls(tvdb_id=tvdb_id, **kwargs) + session.add(media) await session.commit() + + if isinstance(media, Episode): + await session.refresh(media, ["series"]) + if not media.series: + series = Series(tvdb_id=kwargs["series_id"]) + session.add(series) + await session.commit() + + +async def series_get(session: AsyncSession, tvdb_id: int) -> Series | None: + """Get a series from the database.""" + async with session: + return await session.get(Series, tvdb_id) diff --git a/src/db_adapters/misc.py b/src/db_adapters/misc.py new file mode 100644 index 0000000..35a620d --- /dev/null +++ b/src/db_adapters/misc.py @@ -0,0 +1,9 @@ +from sqlalchemy.ext.asyncio import AsyncSession + + +async def refresh[T](session: AsyncSession, item: T, fields: list[str]) -> T: + """Refresh a media item with the specified fields.""" + async with session: + item = await session.merge(item) + await session.refresh(item, fields) + return item diff --git a/src/db_adapters/user.py b/src/db_adapters/user.py index cc52170..69ad9fe 100644 --- a/src/db_adapters/user.py +++ b/src/db_adapters/user.py @@ -7,16 +7,18 @@ async def user_get(session: AsyncSession, discord_id: int) -> User | None: """Get a user by their Discord ID.""" - return await session.get(User, discord_id) + async with session: + return await session.get(User, discord_id) async def user_get_safe(session: AsyncSession, discord_id: int) -> User: """Get a user by their Discord ID, creating them if they don't exist.""" - user = await user_get(session, discord_id) - if user is None: - user = User(discord_id=discord_id) - session.add(user) - await session.commit() + async with session: + user = await user_get(session, discord_id) + if user is None: + user = User(discord_id=discord_id) + session.add(user) + await session.commit() return user @@ -24,13 +26,14 @@ async def user_get_safe(session: AsyncSession, discord_id: int) -> User: async def user_get_list(session: AsyncSession, user: User, name: str) -> UserList | None: """Get a user's list by name.""" # use where clause on user.id and name - user_list = await session.execute( - select(UserList) - .where( - UserList.user_id == user.discord_id, + async with session: + user_list = await session.execute( + select(UserList) + .where( + UserList.user_id == user.discord_id, + ) + .where(UserList.name == name) ) - .where(UserList.name == name) - ) return user_list.scalars().first() @@ -39,14 +42,15 @@ async def user_create_list(session: AsyncSession, user: User, name: str, item_ki :raises ValueError: If a list with the same name already exists for the user. """ - if await user_get_list(session, user, name) is not None: - raise ValueError(f"List with name {name} already exists for user {user.discord_id}.") - user_list = UserList(user_id=user.discord_id, name=name, item_kind=item_kind) - session.add(user_list) - await session.commit() - await session.refresh(user, ["lists"]) + async with session: + if await user_get_list(session, user, name) is not None: + raise ValueError(f"List with name {name} already exists for user {user.discord_id}.") + user_list = UserList(user_id=user.discord_id, name=name, item_kind=item_kind) + session.add(user_list) + await session.commit() + await session.refresh(user, ["lists"]) - return user_list + return user_list async def user_get_list_safe( @@ -57,8 +61,9 @@ async def user_get_list_safe( :param kind: The kind of list to create if it doesn't exist. :return: The user list. """ - user_list = await user_get_list(session, user, name) - if user_list is None: - user_list = await user_create_list(session, user, name, kind) + async with session: + user_list = await user_get_list(session, user, name) + if user_list is None: + user_list = await user_create_list(session, user, name, kind) - return user_list + return user_list diff --git a/src/exts/tvdb_info/ui/episode_view.py b/src/exts/tvdb_info/ui/episode_view.py index e62c074..9cbeac4 100644 --- a/src/exts/tvdb_info/ui/episode_view.py +++ b/src/exts/tvdb_info/ui/episode_view.py @@ -141,7 +141,7 @@ async def set_watched(self, state: bool) -> None: ) if item is None: raise ValueError("Episode is not marked as watched, can't re-mark as unwatched.") - await list_remove_item(self.bot.db_session, self.watched_list, item) + self.watched_list = await list_remove_item(self.bot.db_session, self.watched_list, item) else: try: await list_put_item( diff --git a/src/exts/tvdb_info/ui/movie_series_view.py b/src/exts/tvdb_info/ui/movie_series_view.py index 2c6e35f..7e8d173 100644 --- a/src/exts/tvdb_info/ui/movie_series_view.py +++ b/src/exts/tvdb_info/ui/movie_series_view.py @@ -61,7 +61,7 @@ async def set_favorite(self, state: bool) -> None: item = await get_list_item(self.bot.db_session, self.favorite_list, self.media_data.id, self._db_item_kind) if item is None: raise ValueError("Media is not marked as favorite, can't re-mark as favorite.") - await list_remove_item(self.bot.db_session, self.watched_list, item) + self.watched_list = await list_remove_item(self.bot.db_session, self.watched_list, item) else: try: await list_put_item(self.bot.db_session, self.favorite_list, self.media_data.id, self._db_item_kind) @@ -85,7 +85,7 @@ async def set_watched(self, state: bool) -> None: item = await get_list_item(self.bot.db_session, self.watched_list, self.media_data.id, self._db_item_kind) if item is None: raise ValueError("Media is not marked as watched, can't re-mark as unwatched.") - await list_remove_item(self.bot.db_session, self.watched_list, item) + self.watched_list = await list_remove_item(self.bot.db_session, self.watched_list, item) else: try: await list_put_item(self.bot.db_session, self.watched_list, self.media_data.id, self._db_item_kind) @@ -213,14 +213,14 @@ async def set_watched(self, state: bool) -> None: if not episode.id: raise ValueError("Episode has no ID") - await list_remove_item_safe( + self.watched_list = await list_remove_item_safe( self.bot.db_session, self.watched_list, episode.id, UserListItemKind.EPISODE, ) - await refresh_list_items(self.bot.db_session, self.watched_list) + self.watched_list = await refresh_list_items(self.bot.db_session, self.watched_list) else: for episode in self.media_data.episodes: if not episode.id: diff --git a/src/exts/tvdb_info/ui/profile_view.py b/src/exts/tvdb_info/ui/profile_view.py index 26ea079..2ea0910 100644 --- a/src/exts/tvdb_info/ui/profile_view.py +++ b/src/exts/tvdb_info/ui/profile_view.py @@ -1,12 +1,11 @@ import textwrap from itertools import groupby -from typing import final +from typing import TYPE_CHECKING, final import discord from src.bot import Bot -from src.db_adapters import refresh_list_items -from src.db_tables.media import Episode as EpisodeTable, Movie as MovieTable, Series as SeriesTable +from src.db_adapters import refresh, refresh_list_items, series_get from src.db_tables.user_list import UserList, UserListItemKind from src.exts.error_handler.view import ErrorHandledView from src.settings import MOVIE_EMOJI, SERIES_EMOJI @@ -15,6 +14,9 @@ from src.utils.iterators import get_first from src.utils.log import get_logger +if TYPE_CHECKING: + from src.db_tables.media import Episode as EpisodeTable, Movie as MovieTable, Series as SeriesTable + log = get_logger(__name__) @@ -47,8 +49,8 @@ def __init__( async def _initialize(self) -> None: """Initialize the view, obtaining any necessary state.""" - await refresh_list_items(self.bot.db_session, self.watched_list) - await refresh_list_items(self.bot.db_session, self.favorite_list) + self.watched_list = await refresh_list_items(self.bot.db_session, self.watched_list) + self.favorite_list = await refresh_list_items(self.bot.db_session, self.favorite_list) watched_movies: list[MovieTable] = [] watched_shows: list[SeriesTable] = [] @@ -58,15 +60,15 @@ async def _initialize(self) -> None: for item in self.watched_list.items: match item.kind: case UserListItemKind.MOVIE: - await self.bot.db_session.refresh(item, ["movie"]) - watched_movies.append(item.movie) + refreshed_item = await refresh(self.bot.db_session, item, ["movie"]) + watched_movies.append(refreshed_item.movie) case UserListItemKind.SERIES: - await self.bot.db_session.refresh(item, ["series"]) - watched_shows.append(item.series) + refreshed_item = await refresh(self.bot.db_session, item, ["series"]) + watched_shows.append(refreshed_item.series) case UserListItemKind.EPISODE: - await self.bot.db_session.refresh(item, ["episode"]) - await self.bot.db_session.refresh(item.episode, ["series"]) - watched_episodes.append(item.episode) + refreshed_item = await refresh(self.bot.db_session, item, ["episode"]) + refreshed_item.episode = await refresh(self.bot.db_session, item.episode, ["series"]) + watched_episodes.append(refreshed_item.episode) # We don't actually care about episodes in the profile view, however, we need them # because of the way shows are marked as watched (last episode watched -> show watched). @@ -89,10 +91,10 @@ async def _initialize(self) -> None: group_episode_ids = {episode.tvdb_id for episode in episodes_it} group_episode_ids.add(first_db_episode.tvdb_id) - await self.bot.db_session.refresh(first_db_episode, ["series"]) + first_db_episode = await refresh(self.bot.db_session, first_db_episode, ["series"]) if first_db_episode.series is None: # pyright: ignore[reportUnnecessaryComparison] - manual = await self.bot.db_session.get(SeriesTable, first_db_episode.series_id) + manual = await series_get(self.bot.db_session, first_db_episode.series_id) raise ValueError(f"DB series is None id={first_db_episode.series_id}, manual={manual}") if last_episode.id in group_episode_ids: @@ -106,11 +108,11 @@ async def _initialize(self) -> None: for item in self.favorite_list.items: match item.kind: case UserListItemKind.MOVIE: - await self.bot.db_session.refresh(item, ["movie"]) - favorite_movies.append(item.movie) + refreshed_item = await refresh(self.bot.db_session, item, ["movie"]) + favorite_movies.append(refreshed_item.movie) case UserListItemKind.SERIES: - await self.bot.db_session.refresh(item, ["series"]) - favorite_shows.append(item.series) + refreshed_item = await refresh(self.bot.db_session, item, ["series"]) + favorite_shows.append(refreshed_item.series) case UserListItemKind.EPISODE: raise TypeError("Found an episode in favorite list") diff --git a/src/exts/tvdb_info/ui/search_view.py b/src/exts/tvdb_info/ui/search_view.py index 9434a4d..7be2684 100644 --- a/src/exts/tvdb_info/ui/search_view.py +++ b/src/exts/tvdb_info/ui/search_view.py @@ -78,7 +78,7 @@ async def search_view(bot: Bot, user_id: int, results: Sequence[Movie | Series]) user = await user_get_safe(bot.db_session, user_id) watched_list = await user_get_list_safe(bot.db_session, user, "watched") favorite_list = await user_get_list_safe(bot.db_session, user, "favorite") - await refresh_list_items(bot.db_session, watched_list) - await refresh_list_items(bot.db_session, favorite_list) + watched_list = await refresh_list_items(bot.db_session, watched_list) + favorite_list = await refresh_list_items(bot.db_session, favorite_list) return _search_view(bot, user_id, watched_list, favorite_list, results, 0)