Skip to content

Commit 8615a95

Browse files
committed
pygridtools#96 passing shared value as an argument to grid_map
1 parent 97b677e commit 8615a95

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

gridmap/job.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from email.mime.image import MIMEImage
5151
from io import open
5252
from importlib import import_module
53-
from multiprocessing import Pool
53+
from multiprocessing import Pool, Value
5454
from socket import gethostname, gethostbyname, getaddrinfo, getfqdn
5555
from smtplib import (SMTPRecipientsRefused, SMTPHeloError, SMTPSenderRefused,
5656
SMTPDataError)
@@ -730,15 +730,23 @@ def _execute(job):
730730
job.execute()
731731
return job.ret
732732

733+
def _init_pool_processes(the_val):
734+
'''Initialize each process with a global shared variable.
735+
'''
736+
global shared_val
737+
shared_val = the_val
738+
733739

734-
def _process_jobs_locally(jobs, max_processes=1):
740+
def _process_jobs_locally(jobs, max_processes=1, shared_val=None):
735741
"""
736742
Local execution using the package multiprocessing, if present
737743
738744
:param jobs: jobs to be executed
739745
:type jobs: list of Job
740746
:param max_processes: maximal number of processes
741747
:type max_processes: int
748+
:param shared_val: shared value for the jobs
749+
:type shared_val: multiprocessing.Value, optional
742750
743751
:return: list of jobs, each with return in job.ret
744752
:rtype: list of Job
@@ -751,7 +759,7 @@ def _process_jobs_locally(jobs, max_processes=1):
751759
for job in jobs:
752760
job.execute()
753761
else:
754-
pool = Pool(max_processes)
762+
pool = Pool(processes=max_processes, initializer=_init_pool_processes, initargs=(shared_val,))
755763
result = pool.map(_execute, jobs)
756764
for ret_val, job in zip(result, jobs):
757765
job.ret = ret_val
@@ -856,7 +864,7 @@ def _append_job_to_session(session, job, temp_dir=DEFAULT_TEMP_DIR, quiet=True):
856864

857865

858866
def process_jobs(jobs, temp_dir=DEFAULT_TEMP_DIR, white_list=None, quiet=True,
859-
max_processes=1, local=False, require_cluster=False):
867+
max_processes=1, local=False, require_cluster=False, shared_val=None):
860868
"""
861869
Take a list of jobs and process them on the cluster.
862870
@@ -879,6 +887,8 @@ def process_jobs(jobs, temp_dir=DEFAULT_TEMP_DIR, white_list=None, quiet=True,
879887
:param require_cluster: Should we raise an exception if access to cluster
880888
is not available?
881889
:type require_cluster: bool
890+
:param shared_val: A shared value for all of jobs
891+
:type shared_val: multiprocessing.Value, optional
882892
883893
:returns: List of Job results
884894
"""
@@ -904,7 +914,7 @@ def process_jobs(jobs, temp_dir=DEFAULT_TEMP_DIR, white_list=None, quiet=True,
904914
# handling of inputs, outputs and heartbeats
905915
monitor.check(sid, jobs)
906916
else:
907-
_process_jobs_locally(jobs, max_processes=max_processes)
917+
_process_jobs_locally(jobs, max_processes=max_processes, shared_val=shared_val)
908918

909919
return [job.ret for job in jobs]
910920

@@ -943,7 +953,7 @@ def grid_map(f, args_list, cleanup=True, mem_free="1G", name='gridmap_job',
943953
interpreting_shell=None, copy_env=True, add_env=None, project=None,
944954
validation_level=None, os_distribution=None, os_minor=None, gpu=0,
945955
h_vmem=None, h_rt=None, resources=None, completion_mail=False,
946-
require_cluster=False, par_env=DEFAULT_PAR_ENV):
956+
require_cluster=False, par_env=DEFAULT_PAR_ENV, shared_val=None):
947957
"""
948958
Maps a function onto the cluster.
949959
@@ -1016,6 +1026,8 @@ def grid_map(f, args_list, cleanup=True, mem_free="1G", name='gridmap_job',
10161026
:type os_distribution: str
10171027
:param os_minor: os minor version that need job to run on machine
10181028
:type os_minor: str
1029+
:param shared_val: A shared value for all the jobs
1030+
:type shared_val: multiprocessing.Value, optional
10191031
10201032
:returns: List of Job results
10211033
"""
@@ -1036,7 +1048,8 @@ def grid_map(f, args_list, cleanup=True, mem_free="1G", name='gridmap_job',
10361048
white_list=white_list,
10371049
quiet=quiet, local=local,
10381050
max_processes=max_processes,
1039-
require_cluster=require_cluster)
1051+
require_cluster=require_cluster,
1052+
shared_val=shared_val)
10401053

10411054
# send a completion mail (if requested and configured)
10421055
if completion_mail and SEND_ERROR_MAIL:

0 commit comments

Comments
 (0)