-
Notifications
You must be signed in to change notification settings - Fork 1
/
query.py
96 lines (77 loc) · 4.16 KB
/
query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from functools import lru_cache
import textworld
from textworld.logic import State, Rule, Proposition, Variable
# Unary predicate names describing the state of a food item. Each name is
# substituted into the `{fact}` placeholder of the query-rule templates below.
FOOD_FACTS = [
    "sliced",
    "diced",
    "chopped",
    "cut",
    "uncut",
    "cooked",
    "burned",
    "grilled",
    "fried",
    "roasted",
    "raw",
    "edible",
    "inedible",
]
@lru_cache()
def _rules_predicates_scope():
    """
    Build the query rules that select facts about the room `P` is in.

    Every rule is anchored on `at(P, r)` (P is presumably the player - confirm
    against the game's logic), so only things in room `r` can match. The single
    postcondition of each rule is the fact that the query exposes.

    The result is memoised by `lru_cache`, so the rules are parsed once and the
    same list object is handed back on every call; callers should not mutate it.

    :return: list of parsed `Rule` objects, fixed rules first, then the
             food-state rules generated from `FOOD_FACTS`.
    """
    fixed_rule_texts = (
        # Location of P and of things lying in the same room.
        "query :: at(P, r) -> at(P, r)",
        "query :: at(P, r) & at(o, r) -> at(o, r)",
        "query :: at(P, r) & at(d, r) -> at(d, r)",
        "query :: at(P, r) & at(s, r) -> at(s, r)",
        "query :: at(P, r) & at(c, r) -> at(c, r)",
        # Things resting on `s` in the room.
        "query :: at(P, r) & at(s, r) & on(o, s) -> on(o, s)",
        # Open/closed status of `c` in the room, and contents only when open.
        "query :: at(P, r) & at(c, r) & open(c) -> open(c)",
        "query :: at(P, r) & at(c, r) & closed(c) -> closed(c)",
        "query :: at(P, r) & at(c, r) & open(c) & in(o, c) -> in(o, c)",
        # Status of `d` linking this room to another one.
        "query :: at(P, r) & link(r, d, r') & open(d) -> open(d)",
        "query :: at(P, r) & link(r, d, r') & closed(d) -> closed(d)",
        # Direction of `d`, derived from where the linked room r' lies.
        "query :: at(P, r) & link(r, d, r') & north_of(r', r) -> north_of(d, r)",
        "query :: at(P, r) & link(r, d, r') & south_of(r', r) -> south_of(d, r)",
        "query :: at(P, r) & link(r, d, r') & west_of(r', r) -> west_of(d, r)",
        "query :: at(P, r) & link(r, d, r') & east_of(r', r) -> east_of(d, r)",
    )
    # One template per place `f` can be matched: directly in the room, on
    # something in the room, or inside something open in the room.
    food_rule_templates = (
        "query :: at(P, r) & at(f, r) & {fact}(f) -> {fact}(f)",
        "query :: at(P, r) & at(s, r) & on(f, s) & {fact}(f) -> {fact}(f)",
        "query :: at(P, r) & at(c, r) & open(c) & in(f, c) & {fact}(f) -> {fact}(f)",
    )

    parsed = [Rule.parse(text) for text in fixed_rule_texts]
    for template in food_rule_templates:
        parsed.extend(Rule.parse(template.format(fact=fact)) for fact in FOOD_FACTS)
    return parsed
@lru_cache()
def _rules_predicates_inv():
    """
    Build the query rules that select facts about things matched by `in(_, I)`.

    `I` is presumably the player's inventory - confirm against the game's
    logic. The first rule exposes `in(o, I)` itself; the remaining ones expose
    each `FOOD_FACTS` state of an `f` that is `in(f, I)`.

    The result is memoised by `lru_cache`, so the same list object is returned
    on every call; callers should not mutate it.

    :return: list of parsed `Rule` objects.
    """
    rule_texts = ["query :: in(o, I) -> in(o, I)"]
    rule_texts.extend(
        "query :: in(f, I) & {fact}(f) -> {fact}(f)".format(fact=fact)
        for fact in FOOD_FACTS
    )
    return [Rule.parse(text) for text in rule_texts]
def find_predicates_in_scope(state):
    """
    Collect the facts of `state` that concern what the scope rules match.

    Works in two passes: first gather the first postcondition of every action
    instantiated from `_rules_predicates_scope()`, then add every fact of the
    state whose first argument names an entity mentioned by those predicates.

    :param state: object providing `all_applicable_actions(rules)` and `facts`
                  (a textworld `State` at the call sites in this file).
    :return: set of predicates/facts.
    """
    visible = {
        action.postconditions[0]
        for action in state.all_applicable_actions(_rules_predicates_scope())
    }
    names_in_scope = {name for predicate in visible for name in predicate.names}

    # Widen the result with any other fact about an entity already in scope.
    visible.update(
        fact for fact in state.facts
        if fact.arguments[0].name in names_in_scope
    )
    return visible
def find_predicates_in_inventory(state):
    """
    Collect the facts of `state` that concern what the inventory rules match.

    Works in two passes: first gather the first postcondition of every action
    instantiated from `_rules_predicates_inv()`, then add every fact of the
    state whose first argument names an entity mentioned by those predicates.

    :param state: object providing `all_applicable_actions(rules)` and `facts`
                  (a textworld `State` at the call site in this file).
    :return: set of predicates/facts.
    """
    actions = state.all_applicable_actions(_rules_predicates_inv())
    # Iterate `actions` exactly once. It is only promised to be an iterable and
    # may be a one-shot iterator; a previous throwaway list built from it here
    # would exhaust it and leave the set below empty.
    predicates = {action.postconditions[0] for action in actions}
    entities = {name for p in predicates for name in p.names}

    # Widen the result with any other fact about an entity already matched.
    for fact in state.facts:
        if fact.arguments[0].name in entities:
            predicates.add(fact)
    return predicates
def process_facts(prev_facts, info_game, info_facts, info_last_action, cmd):
    """
    Update the set of known facts after a command has been issued.

    Behaviour by case:

    * `prev_facts` is None: start from an empty set and add the facts
      currently in scope.
    * `cmd` is "inventory": return `prev_facts` merged with the inventory
      facts; the in-scope facts are not added in this case.
    * `info_last_action` is None: return `prev_facts` itself, unchanged.
    * otherwise: apply `info_last_action` to a state built from `prev_facts`
      plus the action's preconditions, take the resulting facts, and add the
      facts currently in scope.

    :param prev_facts: facts known before this command, or None on the first
                       call. Combined with `|`, so presumably a set - confirm
                       against the caller.
    :param info_game: serialized game from the environment's info object; it
                      is deserialized here to reach the knowledge base logic.
    :param info_facts: current facts from the environment's info object.
    :param info_last_action: last action from the info object, or None when no
                             action was recognised.
    :param cmd: the command that was performed.
    :return: the updated facts.
    :raises AssertionError: if applying `info_last_action` does not succeed.
    """
    kb = textworld.Game.deserialize(info_game).kb

    if prev_facts is not None:
        if cmd == "inventory":  # Bypassing TextWorld's action detection.
            in_inventory = find_predicates_in_inventory(State(kb.logic, info_facts))
            return prev_facts | set(in_inventory)

        if info_last_action is None:
            return prev_facts  # Invalid action, nothing has changed.

        # Add the action's preconditions to what we already know so that the
        # action can be applied to the reconstructed state.
        state = State(kb.logic, prev_facts | set(info_last_action.preconditions))
        applied = state.apply(info_last_action)
        assert applied
        known = set(state.facts)
    else:
        known = set()

    # Always add facts in sight.
    known |= set(find_predicates_in_scope(State(kb.logic, info_facts)))
    return known