Skip to content

Commit 6625af7

Browse files
smellthemoon and lixinguo authored
enhance: upsert support autoid (#2173)
milvus-io/milvus#29258 Signed-off-by: lixinguo <[email protected]> Co-authored-by: lixinguo <[email protected]>
1 parent 8e0a27b commit 6625af7

File tree

7 files changed

+68
-46
lines changed

7 files changed

+68
-46
lines changed

pymilvus/client/grpc_handler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,6 @@ def _prepare_batch_upsert_request(
626626
entities: List,
627627
partition_name: Optional[str] = None,
628628
timeout: Optional[float] = None,
629-
is_insert: bool = True,
630629
**kwargs,
631630
):
632631
param = kwargs.get("upsert_param")
@@ -661,7 +660,7 @@ def upsert(
661660

662661
try:
663662
request = self._prepare_batch_upsert_request(
664-
collection_name, entities, partition_name, timeout, False, **kwargs
663+
collection_name, entities, partition_name, timeout, **kwargs
665664
)
666665
rf = self._stub.Upsert.future(request, timeout=timeout)
667666
if kwargs.get("_async", False) is True:

pymilvus/client/prepare.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
ResourceGroupConfig,
2626
get_consistency_level,
2727
)
28-
from .utils import traverse_info, traverse_rows_info
28+
from .utils import traverse_info, traverse_rows_info, traverse_upsert_info
2929

3030

3131
class Prepare:
@@ -462,7 +462,7 @@ def row_upsert_param(
462462
return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
463463

464464
@staticmethod
465-
def _pre_batch_check(
465+
def _pre_insert_batch_check(
466466
entities: List,
467467
fields_info: Any,
468468
):
@@ -493,6 +493,34 @@ def _pre_batch_check(
493493
raise ParamError(msg)
494494
return location
495495

496+
@staticmethod
497+
def _pre_upsert_batch_check(
498+
entities: List,
499+
fields_info: Any,
500+
):
501+
for entity in entities:
502+
if (
503+
entity.get("name") is None
504+
or entity.get("values") is None
505+
or entity.get("type") is None
506+
):
507+
raise ParamError(
508+
message="Missing param in entities, a field must have type, name and values"
509+
)
510+
if not fields_info:
511+
raise ParamError(message="Missing collection meta to validate entities")
512+
513+
location, primary_key_loc = traverse_upsert_info(fields_info)
514+
515+
# though impossible from sdk
516+
if primary_key_loc is None:
517+
raise ParamError(message="primary key not found")
518+
519+
if len(entities) != len(fields_info):
520+
msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}"
521+
raise ParamError(msg)
522+
return location
523+
496524
@staticmethod
497525
def _parse_batch_request(
498526
request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
@@ -533,7 +561,7 @@ def batch_insert_param(
533561
partition_name: str,
534562
fields_info: Any,
535563
):
536-
location = cls._pre_batch_check(entities, fields_info)
564+
location = cls._pre_insert_batch_check(entities, fields_info)
537565
tag = partition_name if isinstance(partition_name, str) else ""
538566
request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag)
539567

@@ -547,7 +575,7 @@ def batch_upsert_param(
547575
partition_name: str,
548576
fields_info: Any,
549577
):
550-
location = cls._pre_batch_check(entities, fields_info)
578+
location = cls._pre_upsert_batch_check(entities, fields_info)
551579
tag = partition_name if isinstance(partition_name, str) else ""
552580
request = milvus_types.UpsertRequest(collection_name=collection_name, partition_name=tag)
553581

pymilvus/client/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,17 @@ def traverse_info(fields_info: Any):
265265
return location, primary_key_loc, auto_id_loc
266266

267267

268+
def traverse_upsert_info(fields_info: Any):
    """Index collection fields by name and locate the primary key.

    Args:
        fields_info: Iterable of dict-like field descriptions. Each one must
            have a ``"name"`` key and may have a truthy ``"is_primary"`` key.

    Returns:
        A ``(location, primary_key_loc)`` tuple. ``location`` maps each field
        name to its position in ``fields_info``. ``primary_key_loc`` is the
        position of the field flagged ``is_primary`` (the last one, if several
        are flagged), or ``None`` when no field is flagged.
    """
    location, primary_key_loc = {}, None
    for i, field in enumerate(fields_info):
        if field.get("is_primary", False):
            primary_key_loc = i

        location[field["name"]] = i

    return location, primary_key_loc
277+
278+
268279
def get_server_type(host: str):
269280
return ZILLIZ if (isinstance(host, str) and "zilliz" in host.lower()) else MILVUS
270281

pymilvus/exceptions.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,6 @@ class InvalidConsistencyLevel(MilvusException):
141141
"""Raise when consistency level is invalid"""
142142

143143

144-
class UpsertAutoIDTrueException(MilvusException):
145-
"""Raise when upsert autoID is true"""
146-
147-
148144
class ExceptionsMessage:
149145
NoHostPort = "connection configuration must contain 'host' and 'port'."
150146
HostType = "Type of 'host' must be str."
@@ -234,3 +230,4 @@ class ExceptionsMessage:
234230
ClusteringKeyOnlyOne = "Expected only one clustering key field, got [%s, %s, ...]."
235231
IsClusteringKeyType = "Param is_clustering_key must be bool type."
236232
ClusteringKeyFieldType = "Param clustering_key_field must be str type."
233+
UpsertPrimaryKeyEmpty = "Upsert need to assign pk"

pymilvus/orm/collection.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
IndexNotExistException,
3535
PartitionAlreadyExistException,
3636
SchemaNotReadyException,
37-
UpsertAutoIDTrueException,
3837
)
3938
from pymilvus.grpc_gen import schema_pb2
4039
from pymilvus.settings import Config
@@ -511,7 +510,7 @@ def insert(
511510
)
512511

513512
check_insert_schema(self.schema, data)
514-
entities = Prepare.prepare_insert_data(data, self.schema)
513+
entities = Prepare.prepare_data(data, self.schema)
515514
return conn.batch_insert(
516515
self._name,
517516
entities,
@@ -622,9 +621,6 @@ def upsert(
622621
10
623622
"""
624623

625-
if self.schema.auto_id:
626-
raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)
627-
628624
if not is_valid_insert_data(data):
629625
raise DataTypeNotSupportException(
630626
message="The type of data should be List, pd.DataFrame or Dict"
@@ -643,7 +639,7 @@ def upsert(
643639
return MutationResult(res)
644640

645641
check_upsert_schema(self.schema, data)
646-
entities = Prepare.prepare_upsert_data(data, self.schema)
642+
entities = Prepare.prepare_data(data, self.schema, False)
647643
res = conn.upsert(
648644
self._name,
649645
entities,

pymilvus/orm/prepare.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,24 @@
1616
import numpy as np
1717
import pandas as pd
1818

19-
from pymilvus.client import utils
2019
from pymilvus.client.types import DataType
2120
from pymilvus.exceptions import (
2221
DataNotMatchException,
2322
DataTypeNotSupportException,
2423
ExceptionsMessage,
2524
ParamError,
26-
UpsertAutoIDTrueException,
2725
)
2826

2927
from .schema import CollectionSchema
3028

3129

3230
class Prepare:
3331
@classmethod
34-
def prepare_insert_data(
32+
def prepare_data(
3533
cls,
3634
data: Union[List, Tuple, pd.DataFrame],
3735
schema: CollectionSchema,
36+
is_insert: bool = True,
3837
) -> List:
3938
if not isinstance(data, (list, tuple, pd.DataFrame)):
4039
raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport)
@@ -46,12 +45,13 @@ def prepare_insert_data(
4645
if (
4746
schema.auto_id
4847
and schema.primary_field.name in data
48+
and is_insert
4949
and not data[schema.primary_field.name].isnull().all()
5050
):
5151
raise DataNotMatchException(message=ExceptionsMessage.AutoIDWithData)
5252
# TODO(SPARSE): support pd.SparseDtype for sparse float vector field
5353
for field in fields:
54-
if field.is_primary and field.auto_id:
54+
if field.is_primary and field.auto_id and is_insert:
5555
continue
5656
values = []
5757
if field.name in list(data.columns):
@@ -63,7 +63,7 @@ def prepare_insert_data(
6363
for i, field in enumerate(tmp_fields):
6464
# TODO Goose: Checking auto_id and is_primary only, maybe different than
6565
# schema.is_primary, schema.auto_id, need to check why and how schema is built.
66-
if field.is_primary and field.auto_id:
66+
if field.is_primary and field.auto_id and is_insert:
6767
tmp_fields.pop(i)
6868

6969
vec_dtype_checker = {
@@ -152,14 +152,3 @@ def prepare_insert_data(
152152
entities.append({"name": field.name, "type": field.dtype, "values": d})
153153

154154
return entities
155-
156-
@classmethod
157-
def prepare_upsert_data(
158-
cls,
159-
data: Union[List, Tuple, pd.DataFrame, utils.SparseMatrixInputType],
160-
schema: CollectionSchema,
161-
) -> List:
162-
if schema.auto_id:
163-
raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)
164-
165-
return cls.prepare_insert_data(data, schema)

pymilvus/orm/schema.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
PartitionKeyException,
2929
PrimaryKeyException,
3030
SchemaNotReadyException,
31-
UpsertAutoIDTrueException,
3231
)
3332

3433
from .constants import COMMON_TYPE_PARAMS
@@ -485,28 +484,23 @@ def _check_insert_data(data: Union[List[List], pd.DataFrame]):
485484
raise DataTypeNotSupportException(message="data should be a list of list")
486485

487486

488-
def _check_data_schema_cnt(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
489-
tmp_fields = copy.deepcopy(schema.fields)
490-
for i, field in enumerate(tmp_fields):
491-
if field.is_primary and field.auto_id:
492-
tmp_fields.pop(i)
493-
494-
field_cnt = len(tmp_fields)
487+
def _check_data_schema_cnt(fields: List, data: Union[List[List], pd.DataFrame]):
488+
field_cnt = len(fields)
495489
is_dataframe = isinstance(data, pd.DataFrame)
496490
data_cnt = len(data.columns) if is_dataframe else len(data)
497491
if field_cnt != data_cnt:
498492
message = (
499493
f"The data don't match with schema fields, expect {field_cnt} list, got {len(data)}"
500494
)
501495
if is_dataframe:
502-
i_name = [f.name for f in tmp_fields]
496+
i_name = [f.name for f in fields]
503497
t_name = list(data.columns)
504498
message = f"The fields don't match with schema fields, expected: {i_name}, got {t_name}"
505499

506500
raise DataNotMatchException(message=message)
507501

508502
if is_dataframe:
509-
for x, y in zip(list(data.columns), tmp_fields):
503+
for x, y in zip(list(data.columns), fields):
510504
if x != y.name:
511505
raise DataNotMatchException(
512506
message=f"The name of field don't match, expected: {y.name}, got {x}"
@@ -524,17 +518,25 @@ def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.Dat
524518
columns.remove(schema.primary_field)
525519
data = data[[columns]]
526520

527-
_check_data_schema_cnt(schema, data)
521+
tmp_fields = copy.deepcopy(schema.fields)
522+
for i, field in enumerate(tmp_fields):
523+
if field.is_primary and field.auto_id:
524+
tmp_fields.pop(i)
525+
526+
_check_data_schema_cnt(tmp_fields, data)
528527
_check_insert_data(data)
529528

530529

531530
def check_upsert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
    """Validate upsert data against the collection schema.

    For DataFrame input the primary-key column must be present and must not
    be entirely null. The data is then checked against a copy of all schema
    fields by ``_check_data_schema_cnt`` and by ``_check_insert_data``.

    Args:
        schema: Schema of the target collection.
        data: Column-based list data or a DataFrame to upsert.

    Raises:
        SchemaNotReadyException: If ``schema`` is None.
        DataNotMatchException: If a DataFrame lacks the primary-key column or
            that column holds only nulls.
    """
    if schema is None:
        raise SchemaNotReadyException(message="Schema shouldn't be None")
    if isinstance(data, pd.DataFrame):
        if schema.primary_field.name not in data or data[schema.primary_field.name].isnull().all():
            raise DataNotMatchException(message=ExceptionsMessage.UpsertPrimaryKeyEmpty)
        columns = list(data.columns)
        # Select with the flat list of labels; wrapping it in another list
        # makes pandas look for a single tuple-valued column label instead.
        data = data[columns]

    _check_data_schema_cnt(copy.deepcopy(schema.fields), data)
    _check_insert_data(data)
539541

540542

0 commit comments

Comments
 (0)