Skip to content

Commit cca0285

Browse files
committed
Add implementation of AsyncSession.run_sync()
Currently, `sqlmodel.ext.asyncio.session.AsyncSession` doesn't implement `run_sync()`, which means that any call to `run_sync()` on a sqlmodel `AsyncSession` will be dispatched to the parent `sqlalchemy.ext.asyncio.AsyncSession`. The first argument to sqlalchemy's `AsyncSession.run_sync()` is a callable whose first argument is a `sqlalchemy.orm.Session` object. If we're using this in a repo that uses sqlmodel, we'll actually be passing a callable whose first argument is a `sqlmodel.orm.session.Session`. In practice this works fine - because `sqlmodel.orm.session.Session` is derived from `sqlalchemy.orm.Session`, the implementation of `sqlalchemy.ext.asyncio.AsyncSession.run_sync()` can use the sqlmodel `Session` object in place of the sqlalchemy `Session` object. However, static analysers will complain that the argument to `run_sync()` is of the wrong type. For example, here's a warning from pyright: ``` Pyright: Error: Argument of type "(session: Session, id: UUID) -> int" cannot be assigned to parameter "fn" of type "(Session, **_P@run_sync) -> _T@run_sync" in function "run_sync"   Type "(session: Session, id: UUID) -> int" is not assignable to type "(Session, id: UUID) -> int"     Parameter 1: type "Session" is incompatible with type "Session"       "sqlalchemy.orm.session.Session" is not assignable to "sqlmodel.orm.session.Session" [reportArgumentType] ``` This commit implements a `run_sync()` method on `sqlmodel.ext.asyncio.session.AsyncSession`, which casts the callable to the correct type before dispatching it to the base class. This satisfies the static type checks.
1 parent 6c0410e commit cca0285

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

sqlmodel/ext/asyncio/session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import (
22
Any,
3+
Callable,
4+
Concatenate,
35
Dict,
46
Mapping,
57
Optional,
@@ -17,6 +19,7 @@
1719
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
1820
from sqlalchemy.ext.asyncio.result import _ensure_sync_result
1921
from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS
22+
from sqlalchemy.orm import Session as _Session
2023
from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
2124
from sqlalchemy.sql.base import Executable as _Executable
2225
from sqlalchemy.util.concurrency import greenlet_spawn
@@ -148,3 +151,17 @@ async def execute( # type: ignore
148151
_parent_execute_state=_parent_execute_state,
149152
_add_event=_add_event,
150153
)
154+
155+
async def run_sync[**P, T](
156+
self,
157+
fn: Callable[Concatenate[Session, P], T],
158+
*arg: P.args,
159+
**kw: P.kwargs,
160+
) -> T:
161+
base_fn = cast(Callable[Concatenate[_Session, P], T], fn)
162+
163+
return await super().run_sync(
164+
base_fn,
165+
*arg,
166+
**kw,
167+
)

0 commit comments

Comments
 (0)