Skip to content

allow client to reconnect if other side closes connection #478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
117 changes: 104 additions & 13 deletions lib/thrift/binary/framed/client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ defmodule Thrift.Binary.Framed.Client do
alias Thrift.TApplicationException
alias Thrift.Transport.SSL

@immutable_tcp_opts [active: false, packet: 4, mode: :binary]
@immutable_tcp_opts [active: true, packet: 4, mode: :binary]

@type error :: {:error, atom} | {:error, {:exception, struct}}
@type success :: {:ok, binary}
Expand All @@ -42,6 +42,7 @@ defmodule Thrift.Binary.Framed.Client do
{:tcp_opts, [tcp_option]}
| {:ssl_opts, [SSL.option()]}
| {:gen_server_opts, [genserver_call_option]}
| {:reconnect, boolean}

@type options :: [option]

Expand All @@ -55,7 +56,8 @@ defmodule Thrift.Binary.Framed.Client do
ssl_opts: [SSL.option()],
timeout: integer,
sock: {:gen_tcp, :gen_tcp.socket()} | {:ssl, :ssl.sslsocket()},
seq_id: integer
seq_id: integer,
reconnect: boolean
}
defstruct host: nil,
port: nil,
Expand All @@ -64,7 +66,8 @@ defmodule Thrift.Binary.Framed.Client do
ssl_opts: nil,
timeout: 5000,
sock: nil,
seq_id: 0
seq_id: 0,
reconnect: false
end

require Logger
Expand All @@ -74,6 +77,7 @@ defmodule Thrift.Binary.Framed.Client do
def init({host, port, opts}) do
tcp_opts = Keyword.get(opts, :tcp_opts, [])
ssl_opts = Keyword.get(opts, :ssl_opts, [])
reconnect = Keyword.get(opts, :reconnect, false)

{timeout, tcp_opts} = Keyword.pop(tcp_opts, :timeout, 5000)

Expand All @@ -82,7 +86,8 @@ defmodule Thrift.Binary.Framed.Client do
port: port,
tcp_opts: tcp_opts,
ssl_opts: ssl_opts,
timeout: timeout
timeout: timeout,
reconnect: reconnect
}

{:connect, :init, s}
Expand Down Expand Up @@ -137,15 +142,20 @@ defmodule Thrift.Binary.Framed.Client do
def close(conn), do: Connection.call(conn, :close)

@impl Connection
def connect(_info, %{sock: nil, host: host, port: port, tcp_opts: opts, timeout: timeout} = s) do
def connect(info, %{sock: nil, host: host, port: port, tcp_opts: opts, timeout: timeout} = s) do
opts =
opts
|> Keyword.merge(@immutable_tcp_opts)
|> Keyword.put_new(:send_timeout, 1000)

# reset sequence id for newly created connection
s = %{s | seq_id: 0}

case :gen_tcp.connect(host, port, opts, timeout) do
{:ok, sock} ->
maybe_ssl_handshake(sock, host, port, s)
sock
|> maybe_ssl_handshake(host, port, s)
|> maybe_resend_data(info)

{:error, :timeout} = error ->
Logger.error("Failed to connect to #{host}:#{port} due to timeout after #{timeout}ms")
Expand All @@ -158,10 +168,13 @@ defmodule Thrift.Binary.Framed.Client do
end

@impl Connection
def disconnect(info, %{sock: {transport, sock}}) do
def disconnect(info, %{sock: {transport, sock}} = s) do
:ok = transport.close(sock)

case info do
{:reconnect, _} ->
{:connect, info, %{s | sock: nil}}

{:close, from} ->
Connection.reply(from, :ok)
{:stop, :normal, nil}
Expand Down Expand Up @@ -244,19 +257,31 @@ defmodule Thrift.Binary.Framed.Client do
end

def handle_call(
{:call, rpc_name, serialized_args, tcp_opts},
_,
%{sock: {transport, sock}, seq_id: seq_id, timeout: default_timeout} = s
{:call, rpc_name, serialized_args, tcp_opts} = msg,
from,
%{
sock: {transport, sock},
seq_id: seq_id,
timeout: default_timeout,
reconnect: reconnect
} = s
) do
s = %{s | seq_id: seq_id + 1}
message = Binary.serialize(:message_begin, {:call, seq_id, rpc_name})
timeout = Keyword.get(tcp_opts, :timeout, default_timeout)

with :ok <- transport.send(sock, [message | serialized_args]),
{:ok, message} <- transport.recv(sock, 0, timeout) do
{:ok, message} <- receive_message(transport, sock, timeout) do
reply = deserialize_message_reply(message, rpc_name, seq_id)
{:reply, reply, s}
else
{:error, :closed} = error ->
if reconnect do
{:disconnect, {:reconnect, {:call, msg, from}}, s}
else
{:disconnect, error, error, s}
end

{:error, :timeout} = error ->
{:disconnect, {:error, :timeout, timeout}, error, s}

Expand All @@ -276,8 +301,8 @@ defmodule Thrift.Binary.Framed.Client do
end

def handle_cast(
{:oneway, rpc_name, serialized_args},
%{sock: {transport, sock}, seq_id: seq_id} = s
{:oneway, rpc_name, serialized_args} = msg,
%{sock: {transport, sock}, seq_id: seq_id, reconnect: reconnect} = s
) do
s = %{s | seq_id: seq_id + 1}
message = Binary.serialize(:message_begin, {:oneway, seq_id, rpc_name})
Expand All @@ -286,15 +311,51 @@ defmodule Thrift.Binary.Framed.Client do
:ok ->
{:noreply, s}

{:error, :closed} = error ->
if reconnect do
{:disconnect, {:reconnect, {:cast, msg}}, s}
else
{:disconnect, error, s}
end

{:error, _} = error ->
{:disconnect, error, s}
end
end

@impl Connection
def handle_info({:tcp_closed, sock}, %{reconnect: true, sock: {_transport, sock}} = s) do
{:disconnect, {:reconnect, nil}, s}
end

def handle_info(_, s) do
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should handle the {:tcp_error, sock, _}, {:ssl_error, sock, _}, {:tcp, sock, _}, {;ssl, sock, _} here but ignore messages from old sockets (or clean them up in disconnect/2).

{:noreply, s}
end

def deserialize_message_reply(message, rpc_name, seq_id) do
handle_message(Binary.deserialize(:message_begin, message), seq_id, rpc_name)
end

defp receive_message(:gen_tcp, sock, timeout) do
receive do
{:tcp, ^sock, data} -> {:ok, data}
{:tcp_closed, ^sock} -> {:error, :closed}
{:tcp_error, ^sock, error} -> {:error, error}
after
timeout -> {:error, :timeout}
end
end

defp receive_message(:ssl, sock, timeout) do
receive do
{:ssl, ^sock, data} -> {:ok, data}
{:ssl_closed, ^sock} -> {:error, :closed}
{:ssl_error, ^sock, error} -> {:error, error}
after
timeout -> {:error, :timeout}
end
end

defp handle_message({:ok, {:reply, seq_id, rpc_name, serialized_response}}, seq_id, rpc_name) do
{:ok, serialized_response}
end
Expand Down Expand Up @@ -372,4 +433,34 @@ defmodule Thrift.Binary.Framed.Client do
{:stop, error, s}
end
end

defp maybe_resend_data({:ok, s}, {:reconnect, {:call, msg, from}}) do
case handle_call(msg, from, s) do
{:reply, reply, s} ->
GenServer.reply(from, reply)
{:ok, s}

{:disconnect, info, error, s} ->
GenServer.reply(from, error)
disconnect(info, s)

_ ->
{:ok, s}
end
end

defp maybe_resend_data({:ok, s}, {:reconnect, {:cast, msg}}) do
case handle_cast(msg, s) do
{:noreply, s} ->
{:ok, s}

{:disconnect, info, s} ->
disconnect(info, s)

_ ->
{:ok, s}
end
end

defp maybe_resend_data(reply, _), do: reply
end
26 changes: 26 additions & 0 deletions test/thrift/binary/framed/server_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,30 @@ defmodule Servers.Binary.Framed.IntegrationTest do
thrift_test "client methods can be called by name instead of pid", %{client_name: name} do
assert {:ok, true} == Client.ping(name)
end

@ping_reply <<128, 1, 0, 2, 0, 0, 0, 4, 112, 105, 110, 103, 0, 0, 0, 0, 2, 0, 0, 1, 0>>
thrift_test "client can reconnect when connection closed by server", ctx do
{:ok, sock} = :gen_tcp.listen(0, [:binary, packet: 4, active: false])
{:ok, port} = :inet.port(sock)

first_conn =
Task.async(fn ->
{:ok, conn} = :gen_tcp.accept(sock)
:ok = :gen_tcp.close(conn)
end)

name = String.to_atom("#{ctx.client_name}_1")
{:ok, client} = Client.start_link("localhost", port, name: name, reconnect: true)

second_conn =
Task.async(fn ->
{:ok, conn} = :gen_tcp.accept(sock)
{:ok, _} = :gen_tcp.recv(conn, 0)
:ok = :gen_tcp.send(conn, @ping_reply)
end)

assert {:ok, true} == Client.ping(client)
Task.await(first_conn)
Task.await(second_conn)
end
end