Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

flexible source dir #72

Open
wants to merge 5 commits into
base: app-ns-3.36+
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions model/ns3gym/ns3gym/ns3env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

class Ns3ZmqBridge(object):
"""docstring for Ns3ZmqBridge"""
def __init__(self, port=0, startSim=True, simSeed=0, simArgs={}, debug=False):
def __init__(self, port=0, startSim=True, simSeed=0, simArgs={}, debug=False, src_dir=os.getcwd(), sim_file=""):
super(Ns3ZmqBridge, self).__init__()
port = int(port)
self.port = port
Expand All @@ -35,6 +35,8 @@ def __init__(self, port=0, startSim=True, simSeed=0, simArgs={}, debug=False):
self.simPid = None
self.wafPid = None
self.ns3Process = None
self.src_dir = src_dir
self.sim_file = sim_file

context = zmq.Context()
self.socket = context.socket(zmq.REP)
Expand Down Expand Up @@ -63,7 +65,7 @@ def __init__(self, port=0, startSim=True, simSeed=0, simArgs={}, debug=False):

if self.startSim:
# run simulation script
self.ns3Process = start_sim_script(port, simSeed, simArgs, debug)
self.ns3Process = start_sim_script(sim_file, port, simSeed, simArgs, debug, src_dir)
else:
print("Waiting for simulation script to connect on port: tcp://localhost:{}".format(port))
print('Please start proper ns-3 simulation script using ./waf --run "..."')
Expand Down Expand Up @@ -361,13 +363,15 @@ def _pack_data(self, actions, spaceDesc):


class Ns3Env(gym.Env):
def __init__(self, stepTime=0, port=0, startSim=True, simSeed=0, simArgs={}, debug=False):
def __init__(self, stepTime=0, port=0, startSim=True, simSeed=0, simArgs={}, debug=False, src_dir=os.getcwd(), sim_file=None):
self.stepTime = stepTime
self.port = port
self.startSim = startSim
self.simSeed = simSeed
self.simArgs = simArgs
self.debug = debug
self.src_dir = src_dir
self.sim_file = sim_file

# Filled in reset function
self.ns3ZmqBridge = None
Expand All @@ -378,7 +382,7 @@ def __init__(self, stepTime=0, port=0, startSim=True, simSeed=0, simArgs={}, deb
self.state = None
self.steps_beyond_done = None

self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug)
self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug, self.src_dir, self.sim_file)
self.ns3ZmqBridge.initialize_env(self.stepTime)
self.action_space = self.ns3ZmqBridge.get_action_space()
self.observation_space = self.ns3ZmqBridge.get_observation_space()
Expand Down Expand Up @@ -413,7 +417,7 @@ def reset(self):
self.ns3ZmqBridge = None

self.envDirty = False
self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug)
self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug, self.src_dir, self.sim_file)
self.ns3ZmqBridge.initialize_env(self.stepTime)
self.action_space = self.ns3ZmqBridge.get_action_space()
self.observation_space = self.ns3ZmqBridge.get_observation_space()
Expand Down
20 changes: 9 additions & 11 deletions model/ns3gym/ns3gym/start_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def find_ns3_path(cwd):
if fname == "ns3":
found = True
ns3_path = os.path.join(my_dir, fname)
break
if fname == "ns-3-dev":
found = True
ns3_path = os.path.join(my_dir, fname, "ns3")

my_dir = os.path.dirname(my_dir)

Expand Down Expand Up @@ -67,18 +69,17 @@ def build_ns3_project(debug=True):
os.chdir(cwd)


def start_sim_script(port=5555, sim_seed=0, sim_args={}, debug=False):
def start_sim_script(sim_file, port=5555, sim_seed=0, sim_args={}, debug=False, src_dir=os.getcwd()):
"""
Actually run the ns3 scenario
"""
cwd = os.getcwd()
sim_script_name = os.path.basename(cwd)
ns3_path = find_ns3_path(cwd)
sim_script_name = os.path.basename(src_dir)
ns3_path = find_ns3_path(src_dir)
base_ns3_dir = os.path.dirname(ns3_path)

os.chdir(base_ns3_dir)

ns3_string = ns3_path + ' run "' + sim_script_name
ns3_string = ns3_path + ' run "' + sim_script_name + '/' + str(sim_file)

if port:
ns3_string += ' --openGymPort=' + str(port)
Expand All @@ -87,10 +88,7 @@ def start_sim_script(port=5555, sim_seed=0, sim_args={}, debug=False):
ns3_string += ' --simSeed=' + str(sim_seed)

for key, value in sim_args.items():
ns3_string += " "
ns3_string += str(key)
ns3_string += "="
ns3_string += str(value)
ns3_string += " --" + str(key) + "=" + str(value)

ns3_string += '"'

Expand Down Expand Up @@ -133,5 +131,5 @@ def start_sim_script(port=5555, sim_seed=0, sim_args={}, debug=False):
print("Started ns3 simulation script, Process Id: ", ns3_proc.pid)

# go back to my dir
os.chdir(cwd)
os.chdir(src_dir)
return ns3_proc
2 changes: 1 addition & 1 deletion model/ns3gym/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pyzmq
numpy
protobuf==3.20.3
protobuf==3.19.0
gym