Commit 985e058

fix: first pass of metadata issue fix
1 parent ebaa086

1 file changed: scripts/aggregate_demix_by_week.py (+62 −26)
@@ -8,7 +8,11 @@
     'Midwest': ['Illinois', 'Indiana', 'Michigan', 'Ohio', 'Wisconsin', 'Iowa', 'Kansas', 'Minnesota', 'Missouri', 'Nebraska', 'North Dakota', 'South Dakota'],
     'South': ['Delaware', 'Maryland', 'Florida', 'Georgia', 'North Carolina', 'South Carolina', 'Virginia', 'District of Columbia', 'West Virginia', 'Alabama', 'Kentucky', 'Mississippi', 'Tennessee', 'Arkansas', 'Louisiana', 'Oklahoma', 'Texas'],
     'West': ['Arizona', 'Colorado', 'Idaho', 'Montana', 'Nevada', 'New Mexico', 'Utah', 'Wyoming', 'Alaska', 'California', 'Hawaii', 'Oregon', 'Washington']
-}
+}
+CENSUS_REGIONS = {region: [state.lower() for state in states] for region, states in CENSUS_REGIONS.items()}
+
+def sum_unique(series):
+    return series.unique().sum()
 
 # Create state to region mapping for easier reference
 STATE_TO_REGION = {}
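
The added comprehension lowercases every state name so later lookups can be case-insensitive. A minimal sketch of the pattern, with a trimmed region map and a hypothetical inversion loop standing in for the `STATE_TO_REGION` construction that follows (not shown in this hunk):

```python
# A trimmed-down region map; the real script lists every state.
CENSUS_REGIONS = {
    'West': ['Arizona', 'California'],
    'South': ['Texas', 'Florida'],
}

# Same lowercasing comprehension the commit adds
CENSUS_REGIONS = {region: [state.lower() for state in states]
                  for region, states in CENSUS_REGIONS.items()}

# Hypothetical inversion loop matching the "state to region mapping" comment
STATE_TO_REGION = {}
for region, states in CENSUS_REGIONS.items():
    for state in states:
        STATE_TO_REGION[state] = region

print(STATE_TO_REGION['california'])  # -> 'West'
```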
@@ -44,6 +48,48 @@
 df_agg['collection_date'] = pd.to_datetime(df_agg['collection_date'])
 df_agg['epiweek'] = df_agg['collection_date'].apply(lambda x: Week.fromdate(x))
 
+# Calculate population, sample count, and site count for each region/week
+region_stats = df_agg.groupby(['geo_loc_region', 'epiweek']).agg({
+    'ww_population': 'sum',  # Sum the populations within each state
+    'sra_accession': 'nunique',
+    'collection_site_id': 'nunique'
+}).reset_index()
+
+# Add census region to region_stats
+region_stats['census_region'] = region_stats['geo_loc_region'].map(STATE_TO_REGION)
+
+# Print value counts for states where census_region is null
+print('census_region null', region_stats[region_stats['census_region'].isna()]['geo_loc_region'].value_counts())
+
+# For census regions, aggregate the already-aggregated state data to avoid double counting
+census_stats = region_stats.groupby(['census_region', 'epiweek']).agg({
+    'ww_population': 'sum',
+    'sra_accession': 'sum',  # Sum the unique counts at state level for region totals
+    'collection_site_id': 'sum'
+}).reset_index()
+
+# For national stats, use the census region data
+nation_stats = census_stats.groupby(['epiweek']).agg({
+    'ww_population': 'sum',
+    'sra_accession': 'sum',
+    'collection_site_id': 'sum'
+}).reset_index()
+nation_stats['geo_loc_region'] = 'USA'
+
+# Create dictionaries from the corrected data
+population_dict = {f"{row['geo_loc_region']}_{row['epiweek']}": row['ww_population'] for _, row in region_stats.iterrows()}
+population_dict.update({f"{row['census_region']}_{row['epiweek']}": row['ww_population'] for _, row in census_stats.iterrows()})
+population_dict.update({f"USA_{row['epiweek']}": row['ww_population'] for _, row in nation_stats.iterrows()})
+
+num_samples_dict = {f"{row['geo_loc_region']}_{row['epiweek']}": row['sra_accession'] for _, row in region_stats.iterrows()}
+num_samples_dict.update({f"{row['census_region']}_{row['epiweek']}": row['sra_accession'] for _, row in census_stats.iterrows()})
+num_samples_dict.update({f"USA_{row['epiweek']}": row['sra_accession'] for _, row in nation_stats.iterrows()})
+
+num_sites_dict = {f"{row['geo_loc_region']}_{row['epiweek']}": row['collection_site_id'] for _, row in region_stats.iterrows()}
+num_sites_dict.update({f"{row['census_region']}_{row['epiweek']}": row['collection_site_id'] for _, row in census_stats.iterrows()})
+num_sites_dict.update({f"USA_{row['epiweek']}": row['collection_site_id'] for _, row in nation_stats.iterrows()})
+
+
 # First aggregate by state
 total_lineage_prevalence_state = df_agg.groupby(['epiweek', 'geo_loc_region']).agg({
     'pop_weighted_prevalence': 'sum',
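
The core of the fix: sample and site counts are deduplicated once per state/week, then the per-state unique counts are summed up to regions so repeated samples aren't recounted. A toy illustration of that nunique-then-sum pattern (column names match the script; rows invented):

```python
import pandas as pd

# Invented rows: SRR1 appears twice (two lineages from one sample), so a
# raw row count would tally it twice.
df = pd.DataFrame({
    'geo_loc_region':     ['California', 'California', 'Arizona'],
    'epiweek':            ['202423',     '202423',     '202423'],
    'sra_accession':      ['SRR1',       'SRR1',       'SRR2'],
    'collection_site_id': ['site_a',     'site_a',     'site_b'],
})

# State/week level: nunique deduplicates repeated samples within a state
region_stats = df.groupby(['geo_loc_region', 'epiweek']).agg({
    'sra_accession': 'nunique',
    'collection_site_id': 'nunique',
}).reset_index()

# Region level: summing the per-state unique counts gives 2 samples;
# len(df) == 3 would have double counted SRR1
print(region_stats['sra_accession'].sum())  # -> 2
```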
@@ -68,14 +114,7 @@
 # Aggregate by state
 df_agg_weekly = df_agg.groupby(['epiweek', 'geo_loc_region', 'name']).agg({
     'pop_weighted_prevalence': 'sum',
-    'collection_site_id': 'nunique',
-    'sra_accession': 'nunique',
-    'ww_population': 'mean',
-}).reset_index().rename(columns={
-    'collection_site_id': 'num_sites',
-    'sra_accession': 'num_samples',
-    'ww_population': 'total_population'
-})
+}).reset_index()
 
 df_agg_weekly['id'] = df_agg_weekly['epiweek'].astype(str) + '_' + df_agg_weekly['geo_loc_region']
 df_agg_weekly['total_lineage_prevalence'] = df_agg_weekly['id'].map(total_prev_state_dict)
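
The counting columns leave this per-lineage groupby because grouping by `name` meant each lineage only saw the samples it was detected in; the corrected counts now come from the region/week dicts instead. A small sketch of the difference (invented rows):

```python
import pandas as pd

# Invented rows: lineages A and B both occur in sample SRR1; B also in SRR2
df = pd.DataFrame({
    'epiweek':        ['202423'] * 3,
    'geo_loc_region': ['California'] * 3,
    'name':           ['A', 'B', 'B'],
    'sra_accession':  ['SRR1', 'SRR1', 'SRR2'],
})

# Old behaviour: counting inside the per-lineage groupby credits lineage A
# with only the samples it appeared in
per_lineage = df.groupby(['epiweek', 'geo_loc_region', 'name'])['sra_accession'].nunique()
print(per_lineage['202423', 'California', 'A'])  # -> 1

# Fixed behaviour: one sample count per state/week, shared by all lineages
per_region = df.groupby(['epiweek', 'geo_loc_region'])['sra_accession'].nunique()
print(per_region['202423', 'California'])  # -> 2
```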
@@ -90,14 +129,7 @@
 # Aggregate by epiweek and lineage
 df_region = df_region_data.groupby(['epiweek', 'name']).agg({
     'pop_weighted_prevalence': 'sum',
-    'collection_site_id': 'nunique',
-    'sra_accession': 'nunique',
-    'ww_population': 'mean',
-}).reset_index().rename(columns={
-    'collection_site_id': 'num_sites',
-    'sra_accession': 'num_samples',
-    'ww_population': 'total_population'
-})
+}).reset_index()
 
 # Calculate proper weighted mean prevalence for region
 df_region['geo_loc_region'] = region
@@ -110,31 +142,35 @@
 # Aggregate by nation (USA)
 df_nation = df_agg.groupby(['epiweek', 'name']).agg({
     'pop_weighted_prevalence': 'sum',
-    'collection_site_id': 'nunique',
-    'sra_accession': 'nunique',
-    'ww_population': 'mean',
-}).reset_index().rename(columns={
-    'collection_site_id': 'num_sites',
-    'sra_accession': 'num_samples',
-    'ww_population': 'total_population'
-})
+}).reset_index()
 
 # Calculate proper weighted mean prevalence for USA
 df_nation['geo_loc_region'] = 'USA'
 df_nation['id'] = df_nation['epiweek'].astype(str) + '_USA'
+df_nation['region_id'] = 'USA_' + df_nation['epiweek'].astype(str)
 df_nation['total_lineage_prevalence'] = df_nation['id'].map(total_prev_nation_dict)
 df_nation['mean_lineage_prevalence'] = df_nation['pop_weighted_prevalence'] / df_nation['total_lineage_prevalence']
+df_nation['total_population'] = df_nation['region_id'].map(population_dict)
+df_nation['num_samples'] = df_nation['region_id'].map(num_samples_dict)
+df_nation['num_sites'] = df_nation['region_id'].map(num_sites_dict)
 
 # Combine all census regions with state data and national data
 df_region_combined = pd.concat(df_agg_census)
 df_agg_weekly = pd.concat([df_agg_weekly, df_region_combined, df_nation])
-df_agg_weekly['total_population'] = df_agg_weekly.groupby(['epiweek', 'geo_loc_region'])['total_population'].transform('mean')  # Ensure total population is consistent across lineages in the same region
 
 df_agg_weekly['id'] = df_agg_weekly['epiweek'].astype(str) + '_' + df_agg_weekly['geo_loc_region'] + '_' + df_agg_weekly['name']
+df_agg_weekly['region_id'] = df_agg_weekly['geo_loc_region'] + '_' + df_agg_weekly['epiweek'].astype(str)
 df_agg_weekly['crumbs'] = df_agg_weekly['name'].map(crumbs)
 df_agg_weekly['week_start'] = df_agg_weekly['epiweek'].apply(lambda x: x.startdate()).astype(str)
 df_agg_weekly['week_end'] = df_agg_weekly['epiweek'].apply(lambda x: x.enddate()).astype(str)
 
+df_agg_weekly['total_population'] = df_agg_weekly['region_id'].map(population_dict)
+df_agg_weekly['num_samples'] = df_agg_weekly['region_id'].map(num_samples_dict)
+df_agg_weekly['num_sites'] = df_agg_weekly['region_id'].map(num_sites_dict)
+
+print('california', df_agg_weekly[(df_agg_weekly['geo_loc_region'] == 'California') & (df_agg_weekly['epiweek'].astype(str) == '202423')]['total_population'].value_counts())
+print('west', df_agg_weekly[(df_agg_weekly['geo_loc_region'] == 'West') & (df_agg_weekly['epiweek'].astype(str) == '202423')]['total_population'].value_counts())
+
 df_agg_weekly = df_agg_weekly[['id', 'epiweek', 'week_start', 'week_end', 'geo_loc_region', 'total_population', 'num_sites', 'num_samples', 'name', 'mean_lineage_prevalence', 'crumbs']]
 
 # Workaround to save to json
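
The new `region_id` columns turn those precomputed dicts into a join key, so every lineage row in the same region/week receives identical metadata, replacing the old `transform('mean')` smoothing. A minimal sketch of the key-and-map join (values invented):

```python
import pandas as pd

# Invented lookup, keyed the same way the commit builds its dicts:
# f"{geo_loc_region}_{epiweek}"
population_dict = {'California_202423': 1500000, 'USA_202423': 30000000}

df_agg_weekly = pd.DataFrame({
    'geo_loc_region': ['California', 'USA'],
    'epiweek':        ['202423', '202423'],  # already strings; the script casts Week to str
    'name':           ['A', 'B'],
})

# Build the key, then map the dict: both lineage rows in a region/week
# get the same total_population
df_agg_weekly['region_id'] = df_agg_weekly['geo_loc_region'] + '_' + df_agg_weekly['epiweek']
df_agg_weekly['total_population'] = df_agg_weekly['region_id'].map(population_dict)
print(df_agg_weekly[['region_id', 'total_population']])
```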
