Skip to content

Commit 07bca0e

Browse files
committed
added a batch insertion function
1 parent 68c87fc commit 07bca0e

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

ecodev_core/db_insertion.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,21 @@
1010
from typing import Union
1111

1212
import pandas as pd
13+
import progressbar
1314
from fastapi import BackgroundTasks
1415
from fastapi import UploadFile
1516
from pandas import ExcelFile
1617
from sqlmodel import Session
1718
from sqlmodel import SQLModel
19+
from sqlmodel.main import SQLModelMetaclass
1820
from sqlmodel.sql.expression import SelectOfScalar
1921

22+
from ecodev_core.db_upsertion import BATCH_SIZE
2023
from ecodev_core.logger import log_critical
2124
from ecodev_core.logger import logger_get
2225
from ecodev_core.pydantic_utils import CustomFrozen
2326
from ecodev_core.safe_utils import SimpleReturn
2427

25-
2628
log = logger_get(__name__)
2729

2830

@@ -72,7 +74,7 @@ async def insert_file(file: UploadFile, insertor: Insertor, session: Session) ->
7274
insert_data(df_raw, insertor, session)
7375

7476

75-
def insert_data(df: Union[pd.DataFrame, ExcelFile], insertor: Insertor, session: Session) -> None:
77+
def insert_data(df: Union[pd.DataFrame, ExcelFile], insertor: Insertor, session: Session) -> None:
7678
"""
7779
Inserts a csv/df into a database
7880
"""
@@ -81,6 +83,24 @@ def insert_data(df: Union[pd.DataFrame, ExcelFile], insertor: Insertor, session
8183
session.commit()
8284

8385

86+
def insert_batch_data(data: list[dict | SQLModelMetaclass],
                      session: Session,
                      raw_db_schema: SQLModelMetaclass | None = None) -> None:
    """
    Insert the passed rows into the db, committing once per batch of BATCH_SIZE rows.

    Each element of data is either a dict of column values (turned into a db_schema instance)
    or an already built model instance (added to the session as is).

    Attributes are:
        - data: the rows to insert. An empty list is a no-op.
        - session: the db session used to add and commit the rows.
        - raw_db_schema: the model class used to build instances out of dict rows. When omitted,
          the class of the first element of data is used instead, so it must be provided
          whenever data contains dicts (the fallback would otherwise be the dict class itself).

    Warning: this only inserts data, without checking for pre-existence.
    Ensure deleting the data before using it to avoid duplicates.
    """
    # Nothing to insert. Also guards the data[0] lookup below against an IndexError.
    if not data:
        return

    db_schema = raw_db_schema or data[0].__class__
    batches = [data[i:i + BATCH_SIZE] for i in range(0, len(data), BATCH_SIZE)]

    for batch in progressbar.progressbar(batches, redirect_stdout=False):
        session.add_all(db_schema(**row) if isinstance(row, dict) else row for row in batch)
        # One commit per batch rather than a single commit for the whole data list.
        session.commit()
102+
103+
84104
def create_or_update(session: Session, row: Dict, insertor: Insertor) -> SQLModel:
85105
"""
86106
Create a new row in db if the selector insertor does not find existing row in db. Update the row

tests/functional/test_db_insertion.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
from fastapi import status
1010
from fastapi import UploadFile
1111
from fastapi.testclient import TestClient
12+
from sqlmodel import Field
1213
from sqlmodel import select
1314
from sqlmodel import Session
15+
from sqlmodel import SQLModel
1416

1517
from ecodev_core import AppUser
1618
from ecodev_core import attempt_to_log
@@ -27,9 +29,9 @@
2729
from ecodev_core import SimpleReturn
2830
from ecodev_core import upsert_app_users
2931
from ecodev_core.app_user import USER_INSERTOR
32+
from ecodev_core.db_insertion import insert_batch_data
3033
from ecodev_core.db_insertion import insert_file
3134

32-
3335
app = FastAPI()
3436
test_client = TestClient(app)
3537
DATA_DIR = Path('/app/tests/unitary/data')
@@ -141,3 +143,43 @@ def test_user_insertion(self):
141143
test_client.post('/user-insert', files={'file': ('filename', f)},
142144
headers={'Authorization': f'Bearer {admin_token}'})
143145
self.assertTrue(len(session.exec(select(AppUser)).all()) == 4)
146+
147+
148+
class InFoo(SQLModel, table=True):  # type: ignore
    """
    Minimal table model backing the batch insertion tests (stored in the in_foo table)
    """
    __tablename__ = 'in_foo'

    # Primary key, left to None when building an instance.
    id: int | None = Field(default=None, primary_key=True)
    # Single integer payload column.
    bar1: int = Field()
155+
156+
157+
class BatchInsertorTest(SafeTestCase):
    """
    Functional tests of insert_batch_data
    """

    def setUp(self):
        """
        Create the InFoo table if needed, then empty it before each test
        """
        super().setUp()
        create_db_and_tables(InFoo)
        delete_table(InFoo)

    def tearDown(self) -> None:
        return super().tearDown()

    def test_insert_batch_data(self):
        """
        Inserting three InFoo instances must yield exactly three rows in db
        """
        rows = [InFoo(bar1=value) for value in (1, 2, 3)]

        with Session(engine) as session:
            insert_batch_data(rows, session)
            stored = session.exec(select(InFoo)).all()

        self.assertEqual(len(stored), 3)

0 commit comments

Comments
 (0)