Skip to content

Commit f37d199

Browse files
committed
Support stdin channel
1 parent 780b85c commit f37d199

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

plugins/kernels/fps_kernels/kernel_server/connect.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,23 @@ async def launch_kernel(
8787
return p
8888

8989

90-
def create_socket(channel: str, cfg: cfg_t) -> Socket:
90+
def create_socket(channel: str, cfg: cfg_t, identity: Optional[bytes] = None) -> Socket:
9191
ip = cfg["ip"]
9292
port = cfg[f"{channel}_port"]
9393
url = f"tcp://{ip}:{port}"
9494
socket_type = channel_socket_types[channel]
9595
sock = context.socket(socket_type)
9696
sock.linger = 1000 # set linger to 1s to prevent hangs at exit
97+
if identity:
98+
sock.identity = identity
9799
sock.connect(url)
98100
return sock
99101

100102

101-
def connect_channel(channel_name: str, cfg: cfg_t) -> Socket:
102-
sock = create_socket(channel_name, cfg)
103+
def connect_channel(
104+
channel_name: str, cfg: cfg_t, identity: Optional[bytes] = None
105+
) -> Socket:
106+
sock = create_socket(channel_name, cfg, identity)
103107
if channel_name == "iopub":
104108
sock.setsockopt(zmq.SUBSCRIBE, b"")
105109
return sock

plugins/kernels/fps_kernels/kernel_server/server.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import uuid
23
import json
34
import asyncio
45
import signal
@@ -105,12 +106,21 @@ async def start(self) -> None:
105106
self.kernelspec_path, self.connection_file_path, self.capture_kernel_output
106107
)
107108
assert self.connection_cfg is not None
108-
self.shell_channel = connect_channel("shell", self.connection_cfg)
109-
self.control_channel = connect_channel("control", self.connection_cfg)
109+
identity = uuid.uuid4().hex.encode("ascii")
110+
self.shell_channel = connect_channel(
111+
"shell", self.connection_cfg, identity=identity
112+
)
113+
self.stdin_channel = connect_channel(
114+
"stdin", self.connection_cfg, identity=identity
115+
)
116+
self.control_channel = connect_channel(
117+
"control", self.connection_cfg, identity=identity
118+
)
110119
self.iopub_channel = connect_channel("iopub", self.connection_cfg)
111120
await self._wait_for_ready()
112121
self.channel_tasks += [
113122
asyncio.create_task(self.listen("shell")),
123+
asyncio.create_task(self.listen("stdin")),
114124
asyncio.create_task(self.listen("control")),
115125
asyncio.create_task(self.listen("iopub")),
116126
]
@@ -154,6 +164,8 @@ async def listen(self, channel_name: str):
154164
channel = self.control_channel
155165
elif channel_name == "iopub":
156166
channel = self.iopub_channel
167+
elif channel_name == "stdin":
168+
channel = self.stdin_channel
157169

158170
while True:
159171
parts = await get_zmq_parts(channel)
@@ -196,6 +208,8 @@ async def send_to_zmq(self, websocket):
196208
send_message(msg, self.shell_channel, self.key)
197209
elif channel == "control":
198210
send_message(msg, self.control_channel, self.key)
211+
elif channel == "stdin":
212+
send_message(msg, self.stdin_channel, self.key)
199213
elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org":
200214
while True:
201215
msg = await websocket.websocket.receive_bytes()
@@ -213,6 +227,8 @@ async def send_to_zmq(self, websocket):
213227
send_raw_message(parts, self.shell_channel, self.key)
214228
elif channel == "control":
215229
send_raw_message(parts, self.control_channel, self.key)
230+
elif channel == "stdin":
231+
send_raw_message(parts, self.stdin_channel, self.key)
216232

217233
async def send_to_ws(self, websocket, parts, parent_header, channel_name):
218234
if not websocket.accepted_subprotocol:

0 commit comments

Comments
 (0)