Skip to content

Commit eee9ad6

Browse files
committed
Overload the intersection and union operators
1 parent 4aac40b commit eee9ad6

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

openpmd_viewer/openpmd_timeseries/particle_tracker.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class ParticleTracker( object ):
3939
to be stored in the openPMD files.
4040
"""
4141

42-
def __init__(self, ts, species=None, t=None,
42+
def __init__(self, ts=None, species=None, t=None,
4343
iteration=None, select=None, preserve_particle_index=False):
4444
"""
4545
Initialize an instance of `ParticleTracker`: select particles at
@@ -69,8 +69,9 @@ def __init__(self, ts, species=None, t=None,
6969
'x' : [-4., 10.] (Particles having x between -4 and 10)
7070
'ux' : [-0.1, 0.1] (Particles having ux between -0.1 and 0.1 mc)
7171
'uz' : [5., None] (Particles with uz above 5 mc).
72-
Can also be a 1d array of interegers corresponding to the
73-
selected particles `id`
72+
Can also be a 1d array of integers corresponding to the
73+
selected particles `id`. In this case, the arguments `ts`, `t`
74+
and `iteration` do not need to be passed.
7475
7576
preserve_particle_index: bool, optional
7677
When retrieving particles at a several iterations,
@@ -105,6 +106,37 @@ def __init__(self, ts, species=None, t=None,
105106
self.species = species
106107
self.preserve_particle_index = preserve_particle_index
107108

109+
def __and__(self, other):
110+
"""
111+
Define the intersection of two ParticleTracker instances.
112+
113+
This selects the particles that are present in both instances.
114+
"""
115+
# Check that both instances are consistent
116+
assert self.species == other.species
117+
assert self.preserve_particle_index == other.preserve_particle_index
118+
119+
# Find the intersection of the selected particles
120+
pid = np.intersect1d( self.selected_pid, other.selected_pid )
121+
pt = ParticleTracker( species=self.species, select=pid,
122+
preserve_particle_index=self.preserve_particle_index )
123+
return pt
124+
125+
def __or__(self, other):
126+
"""
127+
Define the union of two ParticleTracker instances.
128+
129+
This selects the particles that are present in at least one of the instances.
130+
"""
131+
# Check that both instances are consistent
132+
assert self.species == other.species
133+
assert self.preserve_particle_index == other.preserve_particle_index
134+
135+
# Find the union of the selected particles
136+
pid = np.union1d( self.selected_pid, other.selected_pid )
137+
pt = ParticleTracker( species=self.species, select=pid,
138+
preserve_particle_index=self.preserve_particle_index )
139+
return pt
108140

109141
def extract_tracked_particles( self, iteration, data_reader, data_list,
110142
species, extensions ):

0 commit comments

Comments
 (0)