Skip to content

Commit 3f20da8

Browse files
authored
FIX: Refactor prepare_name_pairs_pd to pass arguments to create_positive_negative_samples (#32)
* passed correct_col to prepare_name pairs and replaced hardcoded column names * passed uid_col to prepare_name pairs and changed hardcodes mentions of it * added docstring for correct_col in prepare name pairs * passed more columns to prepare name pairs and replaced their corresponding hardcoded values * passed the new columns also to spark version of training name pairs * added branch to test.yml * passed positive_set_col to create_negative_name_pairs * removed branch from test.yml
1 parent e7f5658 commit 3f20da8

File tree

3 files changed

+30
-9
lines changed

3 files changed

+30
-9
lines changed

emm/data/prepare_name_pairs.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ def prepare_name_pairs_pd(
3939
entity_id_col="entity_id",
4040
gt_entity_id_col="gt_entity_id",
4141
positive_set_col="positive_set",
42+
correct_col="correct",
4243
uid_col="uid",
44+
gt_uid_col="gt_uid",
45+
preprocessed_col="preprocessed",
46+
gt_preprocessed_col="gt_preprocessed",
4347
random_seed=42,
4448
):
4549
"""Prepare dataset of name-pair candidates for training of supervised model.
@@ -70,7 +74,12 @@ def prepare_name_pairs_pd(
7074
For matching name-pairs entity_id == gt_entity_id.
7175
positive_set_col: column that specifies which candidates remain positive and which become negative,
7276
default is "positive_set".
77+
correct_col: column that indicates a correct match, default is "correct".
78+
For entity_id == gt_entity_id the column value is "correct".
7379
uid_col: uid column for names to match, default is "uid".
80+
gt_uid_col: uid column of ground-truth names, default is "gt_uid".
81+
preprocessed_col: name of the preprocessed names column, default is "preprocessed".
82+
gt_preprocessed_col: name of the preprocessed ground-truth names column, default is "gt_preprocessed".
7483
random_seed: random seed for selection of negative names, default is 42.
7584
"""
7685
"""We can have the following dataset.columns, or much more like 'count', 'counterparty_account_count_distinct', 'type1_sum':
@@ -84,7 +93,7 @@ def prepare_name_pairs_pd(
8493
assert entity_id_col in candidates_pd.columns
8594
assert gt_entity_id_col in candidates_pd.columns
8695

87-
candidates_pd["correct"] = candidates_pd[entity_id_col] == candidates_pd[gt_entity_id_col]
96+
candidates_pd[correct_col] = candidates_pd[entity_id_col] == candidates_pd[gt_entity_id_col]
8897

8998
# negative sample creation?
9099
# if so, add positive_set_col column for negative sample creation
@@ -110,14 +119,14 @@ def prepare_name_pairs_pd(
110119
# - happens with one correct/positive case, we just pick the correct one
111120
if drop_duplicate_candidates:
112121
candidates_pd = candidates_pd.sort_values(
113-
["uid", "gt_preprocessed", "correct"], ascending=False
114-
).drop_duplicates(subset=["uid", "gt_preprocessed"], keep="first")
122+
[uid_col, gt_preprocessed_col, correct_col], ascending=False
123+
).drop_duplicates(subset=[uid_col, gt_preprocessed_col], keep="first")
115124
# Similar, for a training set remove all equal names that are not considered a match.
116125
# This can happen a lot in actual data, e.g. with franchises that are independent but have the same name.
117126
# It's a true effect in data, but this screws up our intuitive notion that identical names should be related.
118127
if drop_samename_nomatch:
119-
samename_nomatch = (candidates_pd["preprocessed"] == candidates_pd["gt_preprocessed"]) & ~candidates_pd[
120-
"correct"
128+
samename_nomatch = (candidates_pd[preprocessed_col] == candidates_pd[gt_preprocessed_col]) & ~candidates_pd[
129+
correct_col
121130
]
122131
candidates_pd = candidates_pd[~samename_nomatch]
123132

@@ -133,7 +142,9 @@ def prepare_name_pairs_pd(
133142
# is referred to in: resources/data/howto_create_unittest_sample_namepairs.txt
134143
# create negative sample and rerank negative candidates
135144
# this drops, in part, the negative correct candidates
136-
candidates_pd = create_positive_negative_samples(candidates_pd)
145+
candidates_pd = create_positive_negative_samples(
146+
candidates_pd, uid_col=uid_col, correct_col=correct_col, positive_set_col=positive_set_col
147+
)
137148

138149
# It could be that we dropped all candidates, so we need to re-introduce the no-candidate rows
139150
names_to_match_after = candidates_pd[names_to_match_cols].drop_duplicates()
@@ -142,12 +153,12 @@ def prepare_name_pairs_pd(
142153
)
143154
names_to_match_missing = names_to_match_missing[names_to_match_missing["_merge"] == "left_only"]
144155
names_to_match_missing = names_to_match_missing.drop(columns=["_merge"])
145-
names_to_match_missing["correct"] = False
156+
names_to_match_missing[correct_col] = False
146157
# Since this column is used to calculate benchmark metrics
147158
names_to_match_missing["score_0_rank"] = 1
148159

149160
candidates_pd = pd.concat([candidates_pd, names_to_match_missing], ignore_index=True)
150-
candidates_pd["gt_preprocessed"] = candidates_pd["gt_preprocessed"].fillna("")
151-
candidates_pd["no_candidate"] = candidates_pd["gt_uid"].isnull()
161+
candidates_pd[gt_preprocessed_col] = candidates_pd[gt_preprocessed_col].fillna("")
162+
candidates_pd["no_candidate"] = candidates_pd[gt_uid_col].isnull()
152163

153164
return candidates_pd

emm/pipeline/pandas_entity_matching.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,11 @@ def create_training_name_pairs(
384384
else drop_duplicate_candidates,
385385
create_negative_sample_fraction=create_negative_sample_fraction,
386386
positive_set_col=self.parameters.get("positive_set_col", "positive_set"),
387+
correct_col=self.parameters.get("correct_col", "correct"),
388+
uid_col=self.parameters.get("uid_col", "uid"),
389+
gt_uid_col=self.parameters.get("gt_uid_col", "gt_uid"),
390+
preprocessed_col=self.parameters.get("preprocessed_col", "preprocessed"),
391+
gt_preprocessed_col=self.parameters.get("gt_preprocessed_col", "gt_preprocessed"),
387392
random_seed=random_seed,
388393
**kwargs,
389394
)

emm/pipeline/spark_entity_matching.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,11 @@ def create_training_name_pairs(
412412
else drop_duplicate_candidates,
413413
create_negative_sample_fraction=create_negative_sample_fraction,
414414
positive_set_col=self.parameters.get("positive_set_col", "positive_set"),
415+
correct_col=self.parameters.get("correct_col", "correct"),
416+
uid_col=self.parameters.get("uid_col", "uid"),
417+
gt_uid_col=self.parameters.get("gt_uid_col", "gt_uid"),
418+
preprocessed_col=self.parameters.get("preprocessed_col", "preprocessed"),
419+
gt_preprocessed_col=self.parameters.get("gt_preprocessed_col", "gt_preprocessed"),
415420
random_seed=random_seed,
416421
**kwargs,
417422
)

0 commit comments

Comments
 (0)