Skip to content

Commit 1575ecf

Browse files
author
Arian Jamasb
committed
reset index after handling altlocs #384
1 parent f8b5ef5 commit 1575ecf

File tree

1 file changed

+24
-73
lines changed

1 file changed

+24
-73
lines changed

graphein/protein/graphs.py

+24-73
Original file line number | Diff line number | Diff line change
@@ -96,18 +96,12 @@ def read_pdb_to_dataframe(
9696
:rtype: pd.DataFrame
9797
"""
9898
if pdb_code is None and path is None and uniprot_id is None:
99-
raise NameError(
100-
"One of pdb_code, path or uniprot_id must be specified!"
101-
)
99+
raise NameError("One of pdb_code, path or uniprot_id must be specified!")
102100

103101
if path is not None:
104102
if isinstance(path, Path):
105103
path = os.fsdecode(path)
106-
if (
107-
path.endswith(".pdb")
108-
or path.endswith(".pdb.gz")
109-
or path.endswith(".ent")
110-
):
104+
if path.endswith(".pdb") or path.endswith(".pdb.gz") or path.endswith(".ent"):
111105
atomic_df = PandasPdb().read_pdb(path)
112106
elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"):
113107
atomic_df = PandasMmtf().read_mmtf(path)
@@ -116,9 +110,7 @@ def read_pdb_to_dataframe(
116110
f"File {path} must be either .pdb(.gz), .mmtf(.gz) or .ent, not {path.split('.')[-1]}"
117111
)
118112
elif uniprot_id is not None:
119-
atomic_df = PandasPdb().fetch_pdb(
120-
uniprot_id=uniprot_id, source="alphafold2-v3"
121-
)
113+
atomic_df = PandasPdb().fetch_pdb(uniprot_id=uniprot_id, source="alphafold2-v3")
122114
else:
123115
atomic_df = PandasPdb().fetch_pdb(pdb_code)
124116

@@ -172,11 +164,7 @@ def label_node_id(
172164
df["node_id"] = df["node_id"] + ":" + df["atom_name"]
173165
elif granularity in {"rna_atom", "rna_centroid"}:
174166
df["node_id"] = (
175-
df["node_id"]
176-
+ ":"
177-
+ df["atom_number"].apply(str)
178-
+ ":"
179-
+ df["atom_name"]
167+
df["node_id"] + ":" + df["atom_number"].apply(str) + ":" + df["atom_name"]
180168
)
181169
return df
182170

@@ -189,9 +177,7 @@ def deprotonate_structure(df: pd.DataFrame) -> pd.DataFrame:
189177
:returns: Atomic dataframe with all ``element_symbol == "H" or "D" or "T"`` removed.
190178
:rtype: pd.DataFrame
191179
"""
192-
log.debug(
193-
"Deprotonating protein. This removes H atoms from the pdb_df dataframe"
194-
)
180+
log.debug("Deprotonating protein. This removes H atoms from the pdb_df dataframe")
195181
return filter_dataframe(
196182
df,
197183
by_column="element_symbol",
@@ -225,9 +211,7 @@ def convert_structure_to_centroids(df: pd.DataFrame) -> pd.DataFrame:
225211
return df
226212

227213

228-
def subset_structure_to_atom_type(
229-
df: pd.DataFrame, granularity: str
230-
) -> pd.DataFrame:
214+
def subset_structure_to_atom_type(df: pd.DataFrame, granularity: str) -> pd.DataFrame:
231215
"""
232216
Return a subset of atomic dataframe that contains only certain atom names.
233217
@@ -241,9 +225,7 @@ def subset_structure_to_atom_type(
241225
)
242226

243227

244-
def remove_alt_locs(
245-
df: pd.DataFrame, keep: str = "max_occupancy"
246-
) -> pd.DataFrame:
228+
def remove_alt_locs(df: pd.DataFrame, keep: str = "max_occupancy") -> pd.DataFrame:
247229
"""
248230
This function removes alternatively located atoms from PDB DataFrames
249231
(see https://proteopedia.org/wiki/index.php/Alternate_locations). Among the
@@ -277,7 +259,7 @@ def remove_alt_locs(
277259
# Unsort
278260
if keep in ["max_occupancy", "min_occupancy"]:
279261
df = df.sort_index()
280-
262+
df = df.reset_index(drop=True)
281263
return df
282264

283265

@@ -307,9 +289,7 @@ def remove_insertions(
307289
)
308290

309291

310-
def filter_hetatms(
311-
df: pd.DataFrame, keep_hets: List[str]
312-
) -> List[pd.DataFrame]:
292+
def filter_hetatms(df: pd.DataFrame, keep_hets: List[str]) -> List[pd.DataFrame]:
313293
"""Return hetatms of interest.
314294
315295
:param df: Protein Structure dataframe to filter hetatoms from.
@@ -454,9 +434,7 @@ def sort_dataframe(df: pd.DataFrame) -> pd.DataFrame:
454434
:return: Sorted protein dataframe.
455435
:rtype: pd.DataFrame
456436
"""
457-
return df.sort_values(
458-
by=["chain_id", "residue_number", "atom_number", "insertion"]
459-
)
437+
return df.sort_values(by=["chain_id", "residue_number", "atom_number", "insertion"])
460438

461439

462440
def select_chains(
@@ -558,8 +536,7 @@ def initialise_graph_with_metadata(
558536
elif granularity == "atom":
559537
sequence = (
560538
protein_df.loc[
561-
(protein_df["chain_id"] == c)
562-
& (protein_df["atom_name"] == "CA")
539+
(protein_df["chain_id"] == c) & (protein_df["atom_name"] == "CA")
563540
]["residue_name"]
564541
.apply(three_to_one_with_mods)
565542
.str.cat()
@@ -610,13 +587,9 @@ def add_nodes_to_graph(
610587
# Set intrinsic node attributes
611588
nx.set_node_attributes(G, dict(zip(nodes, chain_id)), "chain_id")
612589
nx.set_node_attributes(G, dict(zip(nodes, residue_name)), "residue_name")
613-
nx.set_node_attributes(
614-
G, dict(zip(nodes, residue_number)), "residue_number"
615-
)
590+
nx.set_node_attributes(G, dict(zip(nodes, residue_number)), "residue_number")
616591
nx.set_node_attributes(G, dict(zip(nodes, atom_type)), "atom_type")
617-
nx.set_node_attributes(
618-
G, dict(zip(nodes, element_symbol)), "element_symbol"
619-
)
592+
nx.set_node_attributes(G, dict(zip(nodes, element_symbol)), "element_symbol")
620593
nx.set_node_attributes(G, dict(zip(nodes, coords)), "coords")
621594
nx.set_node_attributes(G, dict(zip(nodes, b_factor)), "b_factor")
622595

@@ -642,9 +615,7 @@ def calculate_centroid_positions(
642615
:rtype: pd.DataFrame
643616
"""
644617
centroids = (
645-
atoms.groupby(
646-
["residue_number", "chain_id", "residue_name", "insertion"]
647-
)
618+
atoms.groupby(["residue_number", "chain_id", "residue_name", "insertion"])
648619
.mean(numeric_only=True)[["x_coord", "y_coord", "z_coord"]]
649620
.reset_index()
650621
)
@@ -902,13 +873,9 @@ def _mp_graph_constructor(
902873
func = partial(construct_graph, config=config)
903874
try:
904875
if source == "pdb_code":
905-
return func(
906-
pdb_code=args[0], chain_selection=args[1], model_index=args[2]
907-
)
876+
return func(pdb_code=args[0], chain_selection=args[1], model_index=args[2])
908877
elif source == "path":
909-
return func(
910-
path=args[0], chain_selection=args[1], model_index=args[2]
911-
)
878+
return func(path=args[0], chain_selection=args[1], model_index=args[2])
912879
elif source == "uniprot_id":
913880
return func(
914881
uniprot_id=args[0],
@@ -1004,9 +971,7 @@ def construct_graphs_mp(
1004971
)
1005972
if out_path is not None:
1006973
[
1007-
nx.write_gpickle(
1008-
g, str(f"{out_path}/" + f"{g.graph['name']}.pickle")
1009-
)
974+
nx.write_gpickle(g, str(f"{out_path}/" + f"{g.graph['name']}.pickle"))
1010975
for g in graphs
1011976
]
1012977

@@ -1070,15 +1035,11 @@ def compute_chain_graph(
10701035

10711036
# Add edges
10721037
for u, v, d in g.edges(data=True):
1073-
h.add_edge(
1074-
g.nodes[u]["chain_id"], g.nodes[v]["chain_id"], kind=d["kind"]
1075-
)
1038+
h.add_edge(g.nodes[u]["chain_id"], g.nodes[v]["chain_id"], kind=d["kind"])
10761039
# Remove self-loops if necessary. Checks for equality between nodes in a
10771040
# given edge.
10781041
if remove_self_loops:
1079-
edges_to_remove: List[Tuple[str]] = [
1080-
(u, v) for u, v in h.edges() if u == v
1081-
]
1042+
edges_to_remove: List[Tuple[str]] = [(u, v) for u, v in h.edges() if u == v]
10821043
h.remove_edges_from(edges_to_remove)
10831044

10841045
# Compute a weighted graph if required.
@@ -1181,16 +1142,10 @@ def compute_secondary_structure_graph(
11811142
ss_list = ss_list[~ss_list.str.contains("-")]
11821143
# Subset to only allowable SS elements if necessary
11831144
if allowable_ss_elements:
1184-
ss_list = ss_list[
1185-
ss_list.str.contains("|".join(allowable_ss_elements))
1186-
]
1145+
ss_list = ss_list[ss_list.str.contains("|".join(allowable_ss_elements))]
11871146

1188-
constituent_residues: Dict[str, List[str]] = ss_list.index.groupby(
1189-
ss_list.values
1190-
)
1191-
constituent_residues = {
1192-
k: list(v) for k, v in constituent_residues.items()
1193-
}
1147+
constituent_residues: Dict[str, List[str]] = ss_list.index.groupby(ss_list.values)
1148+
constituent_residues = {k: list(v) for k, v in constituent_residues.items()}
11941149
residue_counts: Dict[str, int] = ss_list.groupby(ss_list).count().to_dict()
11951150

11961151
# Add Nodes from secondary structure list
@@ -1209,9 +1164,7 @@ def compute_secondary_structure_graph(
12091164
# Iterate over edges in source graph and add SS-SS edges to new graph.
12101165
for u, v, d in g.edges(data=True):
12111166
try:
1212-
h.add_edge(
1213-
ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}"
1214-
)
1167+
h.add_edge(ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}")
12151168
except KeyError as e:
12161169
log.debug(
12171170
f"Edge {u}-{v} not added to secondary structure graph. \
@@ -1221,9 +1174,7 @@ def compute_secondary_structure_graph(
12211174
# Remove self-loops if necessary.
12221175
# Checks for equality between nodes in a given edge.
12231176
if remove_self_loops:
1224-
edges_to_remove: List[Tuple[str]] = [
1225-
(u, v) for u, v in h.edges() if u == v
1226-
]
1177+
edges_to_remove: List[Tuple[str]] = [(u, v) for u, v in h.edges() if u == v]
12271178
h.remove_edges_from(edges_to_remove)
12281179

12291180
# Create weighted graph from h

0 commit comments

Comments
 (0)