Skip to content

Commit 9c78d86

Browse files
committed
fix optionals
1 parent e245624 commit 9c78d86

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

ipython2cwl/cwltoolextractor.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import astor
1212
import nbconvert
1313
import yaml
14+
from astor.string_repr import pretty_string
1415
from nbformat.notebooknode import NotebookNode
1516

1617
from .iotypes import CWLFilePathInput, CWLBooleanInput, CWLIntInput, CWLStringInput, CWLFilePathOutput
@@ -78,7 +79,7 @@ def visit_AnnAssign(self, node):
7879
if annotation in self.input_type_mapper:
7980
mapper = self.input_type_mapper[annotation]
8081
self.extracted_nodes.append(
81-
(node, mapper[0], mapper[1], True, True, False)
82+
(node, mapper[0], mapper[1], not mapper[0].endswith('?'), True, False)
8283
)
8384
return None
8485

@@ -176,18 +177,24 @@ def _wrap_script_to_method(cls, tree, variables) -> str:
176177
main_function = ast.parse(main_template_code)
177178
[node for node in main_function.body if isinstance(node, ast.FunctionDef) and node.name == 'main'][0] \
178179
.body = tree.body
179-
return astor.to_source(main_function)
180+
return astor.to_source(
181+
main_function,
182+
pretty_string=lambda s, embedded, current_line, uni: pretty_string(s, embedded, current_line, uni, max_line=500)
183+
)
180184

181185
@classmethod
182186
def __get_add_arguments__(cls, variables):
183187
args = []
184188
for variable in variables:
185189
is_array = variable.cwl_typeof.endswith('[]')
190+
is_optional = variable.cwl_typeof.endswith('?')
186191
arg: str = f'parser.add_argument("--{variable.name}", '
187192
arg += f'type={variable.argparse_typeof}, '
188193
arg += f'required={variable.required}, '
189194
if is_array:
190195
arg += f'nargs="+", '
196+
if is_optional:
197+
arg += f'default=None, '
191198
arg = arg.strip()
192199
arg += ')'
193200
args.append(arg)

tests/simple.ipynb

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"outputs": [],
88
"source": [
99
"import pandas as pd\n",
10-
"from typing import List\n",
10+
"from typing import List, Optional\n",
1111
"import matplotlib\n",
1212
"from ipython2cwl.iotypes import CWLFilePathInput, CWLFilePathOutput, CWLStringInput"
1313
]
@@ -19,7 +19,8 @@
1919
"outputs": [],
2020
"source": [
2121
"dataset: CWLFilePathInput = './data/data.csv'\n",
22-
"messages: List[CWLStringInput] = ['hello', 'world']"
22+
"messages: List[CWLStringInput] = ['hello', 'world']\n",
23+
"optional_message: Optional[CWLStringInput] = \"Hello from optional\""
2324
]
2425
},
2526
{
@@ -60,6 +61,16 @@
6061
"with open(messages_filename, 'w') as f:\n",
6162
" f.write(' '.join(messages))"
6263
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": null,
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
71+
"if optional_message is not None:\n",
72+
" print(optional_message)"
73+
]
6374
}
6475
],
6576
"metadata": {

tests/test_ipython2cwl_from_repo.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def test_docker_build(self):
4040
docker_client = docker.from_env()
4141
script = docker_client.containers.run(dockerfile_image_id, '/app/cwl/bin/simple', entrypoint='/bin/cat')
4242
self.assertIn('fig.figure.savefig(after_transform_data)', script.decode())
43-
messages_array_arg_line = ast.parse([line.strip() for line in script.decode().splitlines() if '--messages' in line][-1])
43+
messages_array_arg_line = ast.parse(
44+
[line.strip() for line in script.decode().splitlines() if '--messages' in line][-1]
45+
)
4446
self.assertEqual(
4547
'+', # nargs = '+'
4648
[k.value.s for k in messages_array_arg_line.body[0].value.keywords if k.arg == 'nargs'][0]
@@ -49,6 +51,19 @@ def test_docker_build(self):
4951
'str', # type = 'str'
5052
[k.value.id for k in messages_array_arg_line.body[0].value.keywords if k.arg == 'type'][0]
5153
)
54+
55+
script_tree = ast.parse(script.decode())
56+
optional_expression = [x for x in script_tree.body[-1].body if
57+
isinstance(x, ast.Expr) and isinstance(x.value, ast.Call) and len(x.value.args) > 0 and
58+
x.value.args[0].s == '--optional_message'][0]
59+
self.assertEqual(
60+
False,
61+
[k.value for k in optional_expression.value.keywords if k.arg == 'required'][0].value
62+
)
63+
self.assertEqual(
64+
None,
65+
[k.value for k in optional_expression.value.keywords if k.arg == 'default'][0].value
66+
)
5267
self.assertDictEqual(
5368
{
5469
'cwlVersion': "v1.1",
@@ -70,6 +85,12 @@ def test_docker_build(self):
7085
'inputBinding': {
7186
'prefix': '--messages'
7287
}
88+
},
89+
'optional_message': {
90+
'type': 'string?',
91+
'inputBinding': {
92+
'prefix': '--optional_message'
93+
}
7394
}
7495
},
7596
'outputs': {

0 commit comments

Comments
 (0)