Skip to content

Feature/158-TUS-simulator #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 30 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
5a49833
Create class to represent an information need
NoB0 May 13, 2024
634b2c7
Add tests for information need
NoB0 May 13, 2024
ea422e5
Fix version rasa
NoB0 May 13, 2024
025da95
Update requirements to reduce backtracking
NoB0 May 13, 2024
3e7e45f
Update test
NoB0 May 13, 2024
e4315ec
Format with black
NoB0 May 13, 2024
a1a82b9
Format docstring
NoB0 May 13, 2024
3ac5192
Add dialogue state and dialogue state tracker
NoB0 May 29, 2024
b7208c5
Add tests for DST
NoB0 May 29, 2024
1d2dacd
Black
NoB0 May 29, 2024
41e6fba
Implement transformer model
NoB0 May 29, 2024
853cfbb
Fix pre-commit
NoB0 May 29, 2024
11247e7
Merge branch 'feature/TUS-feature-handler' into feature/158-TUS-featu…
NoB0 May 29, 2024
32506a1
Merge branch 'feature/core-neural-us' into feature/158-TUS-feature-ha…
NoB0 May 29, 2024
7b8ba16
Add TUS feature handler
NoB0 May 29, 2024
9ec13e4
Merge branch 'feature/163-Add-dialogue-management-components' into fe…
NoB0 May 29, 2024
b511be1
Update tests
NoB0 May 29, 2024
7daa047
Black
NoB0 May 29, 2024
763008b
Update feature handler
NoB0 May 30, 2024
bb0f715
Merge branch 'main' into feature/158-TUS-feature-handler
NoB0 Jun 4, 2024
6901be4
Reorganize modules
NoB0 Jun 4, 2024
381fd32
Update feature handler
NoB0 Jun 25, 2024
8dcb725
Add missing requirement
NoB0 Jun 25, 2024
554b826
Add TUS
NoB0 Jun 25, 2024
962375d
Merge branch 'main' into feature/158-TUS-simulator
NoB0 Aug 14, 2024
4668e01
Remove duplicate imports
NoB0 Aug 14, 2024
a0d48fa
Update response generation in TUS
NoB0 Aug 14, 2024
a0911ac
Merge branch 'main' into feature/158-TUS-simulator
NoB0 Aug 21, 2024
42a7aca
Update `_get_label`
NoB0 Aug 21, 2024
5ede25e
Merge branch 'main' into feature/158-TUS-simulator
NoB0 Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions usersimcrs/simulator/neural/core/transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
"""Encoder-only transformer model for neural user simulator."""
"""Encoder-only transformer model for neural user simulator.

Implementation inspired by PyTorch documentation and TUS's transformer model.

Sources:
https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/dca13261bbb4e9809d1a3aa521d22dd7/transformer_tutorial.ipynb#scrollTo=R8veciavth40
https://gitlab.cs.uni-duesseldorf.de/general/dsml/tus_public/-/blob/master/convlab2/policy/tus/multiwoz/transformer.py?ref_type=heads
"""

import math

Expand Down Expand Up @@ -54,7 +61,6 @@ def __init__(
nhead: int,
hidden_dim: int,
num_encoder_layers: int,
num_token: int,
dropout: float = 0.5,
) -> None:
"""Initializes a encoder-only transformer model.
Expand All @@ -69,15 +75,15 @@ def __init__(
dropout: Dropout rate. Defaults to 0.5.
"""
super(TransformerEncoderModel, self).__init__()
self.d_model = input_dim
self.d_model = hidden_dim

self.pos_encoder = PositionalEncoding(input_dim, dropout)
self.embedding = nn.Embedding(num_token, input_dim)
self.pos_encoder = PositionalEncoding(hidden_dim, dropout)
self.embedding = nn.Linear(input_dim, hidden_dim)

# Encoder layers
norm_layer = nn.LayerNorm(input_dim)
norm_layer = nn.LayerNorm(hidden_dim)
encoder_layer = nn.TransformerEncoderLayer(
d_model=input_dim,
d_model=self.d_model,
nhead=nhead,
dim_feedforward=hidden_dim,
)
Expand All @@ -87,7 +93,7 @@ def __init__(
norm=norm_layer,
)

self.linear = nn.Linear(input_dim, output_dim)
self.linear = nn.Linear(hidden_dim, output_dim)
self.softmax = nn.Softmax(dim=-1)

self.init_weights()
Expand All @@ -113,6 +119,8 @@ def forward(
"""
src = self.embedding(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
output = self.encoder(src, mask=src_mask)
src = src.permute(1, 0, 2)
output = self.encoder(src, src_key_padding_mask=src_mask)
output = self.linear(output)
output = output.permute(1, 0, 2)
return output
235 changes: 235 additions & 0 deletions usersimcrs/simulator/neural/tus/tus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""Transformer-based User Simulator (TUS)

Reference: Domain-independent User Simulation with Transformers for
Task-oriented Dialogue Systems, Lin et al., 2021.
See: https://arxiv.org/abs/2106.08838

Implementation is adapted from the description in the paper and the original
implementation by the authors:
https://gitlab.cs.uni-duesseldorf.de/general/dsml/tus_public
"""

import logging
import random
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List

import torch

from dialoguekit.core.dialogue_act import DialogueAct
from dialoguekit.core.slot_value_annotation import SlotValueAnnotation
from dialoguekit.core.utterance import Utterance
from dialoguekit.nlg.nlg_conditional import ConditionalNLG
from dialoguekit.nlu.nlu import NLU
from dialoguekit.participant import DialogueParticipant
from usersimcrs.core.simulation_domain import SimulationDomain
from usersimcrs.dialogue_management.dialogue_state_tracker import (
DialogueStateTracker,
)
from usersimcrs.items.item_collection import ItemCollection
from usersimcrs.simulator.neural.core.feature_handler import (
FeatureMask,
FeatureVector,
)
from usersimcrs.simulator.neural.core.transformer import (
TransformerEncoderModel,
)
from usersimcrs.simulator.neural.tus.tus_feature_handler import (
TUSFeatureHandler,
)
from usersimcrs.simulator.user_simulator import UserSimulator

logger = logging.getLogger(__name__)


class TUS(UserSimulator):
    """Transformer-based User Simulator (TUS).

    For every agent utterance the simulator runs NLU, updates the dialogue
    state, builds the turn features, predicts one action label per slot with
    a transformer encoder, and verbalizes the resulting dialogue acts with a
    conditional NLG.
    """

    # Number of distinct per-slot action labels (0 to 5) handled when decoding
    # the policy network output.
    _NUM_ACTION_LABELS = 6

    def __init__(
        self,
        id: str,
        domain: SimulationDomain,
        item_collection: ItemCollection,
        nlu: NLU,
        feature_handler: TUSFeatureHandler,
        dialogue_state_tracker: DialogueStateTracker,
        network_config: Dict[str, Any],
        nlg: ConditionalNLG,
    ) -> None:
        """Initializes the Transformer-based User Simulator (TUS).

        Args:
            id: Simulator ID.
            domain: Domain knowledge.
            item_collection: Collection of items.
            nlu: NLU module.
            feature_handler: Feature handler.
            dialogue_state_tracker: Dialogue state tracker.
            network_config: Network configuration. It is unpacked as keyword
              arguments of TransformerEncoderModel.
            nlg: NLG module generating textual responses.
        """
        super().__init__(id=id, domain=domain, item_collection=item_collection)
        self._nlu = nlu
        self._nlg = nlg
        self._tus_feature_handler = feature_handler
        self._user_policy_network = TransformerEncoderModel(**network_config)
        self._dialogue_state_tracker = dialogue_state_tracker
        # One-hot vector of the last predicted action per slot. Slots without
        # a prediction yet default to an empty tensor.
        self._last_user_actions: DefaultDict[str, torch.Tensor] = defaultdict(
            lambda: torch.tensor([])
        )
        # None until set; only reset (never assigned a tensor) in this class.
        self._last_turn_input: torch.Tensor = None

    def initialize(self) -> None:
        """Initializes the user simulator.

        Resets the dialogue state and forgets the actions and input of the
        previous turn.
        """
        self._dialogue_state_tracker.reset_state()
        self._last_user_actions.clear()
        self._last_turn_input = None

    def _generate_response(self, agent_utterance: Utterance) -> Utterance:
        """Generates response to the agent utterance.

        Args:
            agent_utterance: Agent utterance.

        Returns:
            User utterance.
        """
        # NOTE(review): this assumes get_current_state() returns a snapshot
        # that is not mutated by update_state() below; confirm against the
        # dialogue state tracker.
        previous_state = self._dialogue_state_tracker.get_current_state()

        # 1. Perform NLU on the agent utterance, i.e., extract dialogue acts.
        agent_dialogue_acts = self._nlu.extract_dialogue_acts(agent_utterance)

        # 2. Update dialogue state based on the agent dialogue acts.
        self._dialogue_state_tracker.update_state(
            dialogue_acts=agent_dialogue_acts,
            participant=DialogueParticipant.AGENT,
        )

        # 3. Extract features for the current turn.
        turn_feature, mask = self._tus_feature_handler.build_input_vector(
            agent_dialogue_acts=agent_dialogue_acts,
            previous_state=previous_state,
            state=self._dialogue_state_tracker.get_current_state(),
            information_need=self.information_need,
            user_action_vectors=self._last_user_actions,
        )

        # 4. Predict user dialogue acts based on the features.
        user_dialogue_acts = self.predict_user_dialogue_acts(
            turn_feature, mask, self._tus_feature_handler.action_slots
        )

        # 5. Generate user utterance based on the predicted dialogue acts.
        response = self._nlg.generate_utterance_text(user_dialogue_acts)
        response.participant = DialogueParticipant.USER

        # 6. Update dialogue state based on the user dialogue acts.
        self._dialogue_state_tracker.update_state(
            dialogue_acts=response.dialogue_acts,
            participant=DialogueParticipant.USER,
        )

        return response

    def predict_user_dialogue_acts(
        self,
        features: List[FeatureVector],
        mask: FeatureMask,
        action_slots: List[str],
    ) -> List[DialogueAct]:
        """Predicts user dialogue acts based on the features.

        The predicted label of each slot is also stored as a one-hot vector
        in the last user actions, which are passed to the feature handler on
        the next turn.

        Args:
            features: Feature vector.
            mask: Mask vector.
            action_slots: Action slots used to predict the user action per slot.

        Returns:
            Predicted user dialogue acts.
        """
        output = self._user_policy_network(features, mask)
        max_length = self._tus_feature_handler.max_turn_feature_length
        # Drop the first position and keep at most max_length positions.
        output = output[:, 1 : max_length + 1, :]  # noqa: E203

        slot_outputs: Dict[str, int] = {}
        for index, slot_name in enumerate(action_slots):
            # NOTE(review): the first position was already removed by the
            # slice above, so "index + 1" skips one more position; confirm
            # this offset against the layout of the input vector.
            label = int(torch.argmax(output[0, index + 1, :]).item())
            assert label in range(
                self._NUM_ACTION_LABELS
            ), f"Invalid output: {label}"
            slot_outputs[slot_name] = label

            # One-hot encoding of user action for the slot
            one_hot = torch.zeros(self._NUM_ACTION_LABELS)
            one_hot[label] = 1
            self._last_user_actions[slot_name] = one_hot

        return self._parse_policy_output(action_slots, slot_outputs)

    def _parse_policy_output(
        self, action_slots: List[str], slot_outputs: Dict[str, int]
    ) -> List[DialogueAct]:
        """Parses the policy output to dialogue acts.

        The labels are decoded with the same encoding that
        TUSFeatureHandler._get_label uses to build the training labels:
          1: The slot's value is "dontcare".
          2: The slot's value is requested by the user.
          3: The slot's value is taken from the information need.
          4: The slot's value is retrieved from the belief state.
          5: The slot's value is randomly chosen.
        Any other label means the slot is not mentioned in this turn.

        Args:
            action_slots: Action slots.
            slot_outputs: Output per slot.

        Returns:
            Dialogue acts.
        """
        belief_state = (
            self._dialogue_state_tracker.get_current_state().belief_state
        )
        constraints = self.information_need.constraints
        dialogue_acts = []

        for slot in action_slots:
            label = slot_outputs[slot]
            dialogue_act = DialogueAct()

            # Default intent is "inform"
            dialogue_act.intent = "inform"

            # Determine the value of the slot
            if label == 1:
                # The slot's value is set to "dontcare"
                dialogue_act.annotations.append(
                    SlotValueAnnotation(slot, "dontcare")
                )
            elif label == 2:
                # The slot's value is requested by the user
                dialogue_act.intent = "request"
                dialogue_act.annotations.append(SlotValueAnnotation(slot))
            elif label == 3:
                # The slot's value is taken from the information need
                # NOTE(review): when the slot is not a constraint, an
                # "inform" act without annotation is still emitted; confirm
                # that this is intended.
                if slot in constraints:
                    dialogue_act.annotations.append(
                        SlotValueAnnotation(slot, constraints[slot])
                    )
            elif label == 4:
                # The slot's value was previously mentioned and is retrieved
                # from the belief state
                # NOTE(review): same remark as above when the slot is absent
                # from the belief state.
                if slot in belief_state:
                    dialogue_act.annotations.append(
                        SlotValueAnnotation(slot, belief_state[slot])
                    )
            elif label == 5:
                # The slot's value in the information need is randomly modified
                candidates = list(
                    self._item_collection.get_possible_property_values(slot)
                )
                if not candidates:
                    # random.choice raises IndexError on an empty sequence.
                    logger.warning(
                        "No possible value to choose from for %s.", slot
                    )
                    continue
                value = random.choice(candidates)
                constraints[slot] = value
                dialogue_act.annotations.append(
                    SlotValueAnnotation(slot, value)
                )
            else:
                logger.warning("%s is not mentioned in this turn.", slot)
                continue

            dialogue_acts.append(dialogue_act)

        return dialogue_acts
35 changes: 18 additions & 17 deletions usersimcrs/simulator/neural/tus/tus_feature_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import torch

from dialoguekit.core.annotated_utterance import AnnotatedUtterance
from dialoguekit.core.annotation import Annotation
from dialoguekit.core.dialogue_act import DialogueAct
from dialoguekit.core.slot_value_annotation import SlotValueAnnotation
from usersimcrs.core.information_need import InformationNeed
from usersimcrs.core.simulation_domain import SimulationDomain
from usersimcrs.dialogue_management.dialogue_state import DialogueState
Expand Down Expand Up @@ -342,7 +342,7 @@ def build_input_vector(
a special token. Note that is inferred from the list of feature vectors
provided.
The input vector is structured as follows:
[CLS] $V^t$ [SEP] $V^{t-1}$ [SEP] ... [SEP] $V^{t-n}$ [SEP] {padding},
[CLS] $V^t$ [SEP] $V^{t-1}$ [SEP] ... [SEP] $V^{t-n}$ [SEP] {padding}
where {padding} indicates the padding to reach the maximum length.

Args:
Expand Down Expand Up @@ -429,7 +429,7 @@ def get_label_vector(

def _get_label(
self,
annotation: Annotation,
slot_value_annotation: SlotValueAnnotation,
current_state: DialogueState,
information_need: InformationNeed,
) -> int:
Expand All @@ -445,42 +445,43 @@ def _get_label(
5: The slot's value is randomly chosen.

Args:
annotation: Annotation.
slot_value_annotation: Slot value pair annotation.
current_state: Current state.
information_need: Information need.

Returns:
Label.
"""
if annotation.value == "dontcare":
if slot_value_annotation.value == "dontcare":
return 1
elif annotation.value is None:
elif slot_value_annotation.value is None:
# The value is requested by the user
return 2
elif annotation.value == information_need.get_constraint_value(
annotation.slot
) or annotation.value == information_need.requested_slots.get(
annotation.slot
elif (
slot_value_annotation.value
== information_need.get_constraint_value(slot_value_annotation.slot)
or slot_value_annotation.value
== information_need.requested_slots.get(slot_value_annotation.slot)
):
# The value is taken from the information need
return 3
elif annotation.value == current_state.belief_state.get(
annotation.slot
elif slot_value_annotation.value == current_state.belief_state.get(
slot_value_annotation.slot
):
# The value was previously mentioned and is
# retrieved from the belief state
return 4
elif (
annotation.slot
slot_value_annotation.slot
in [
information_need.constraints.keys(),
information_need.requested_slots.keys(),
current_state.belief_state.keys(),
]
) and annotation.value not in [
information_need.get_constraint_value(annotation.slot),
information_need.requested_slots.get(annotation.slot),
current_state.belief_state.get(annotation.slot),
) and slot_value_annotation.value not in [
information_need.get_constraint_value(slot_value_annotation.slot),
information_need.requested_slots.get(slot_value_annotation.slot),
current_state.belief_state.get(slot_value_annotation.slot),
]:
# The slot's value is randomly chosen
return 5
Expand Down
8 changes: 7 additions & 1 deletion usersimcrs/simulator/user_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ def __init__(
domain: SimulationDomain,
item_collection: ItemCollection,
) -> None:
"""Initializes the user simulator."""
"""Initializes the user simulator.

Args:
id: User ID.
domain: Domain knowledge.
item_collection: Collection of items.
"""
super().__init__(id, UserType.SIMULATOR)
self._domain = domain
self._item_collection = item_collection
Expand Down
Loading