Skip to content

Commit

Permalink
Add more traces and fix bug (#15)
Browse files Browse the repository at this point in the history
Signed-off-by: Ze Gan <[email protected]>
  • Loading branch information
Pterosaur authored Apr 13, 2021
1 parent 72dd6d9 commit 3b5f161
Show file tree
Hide file tree
Showing 11 changed files with 13,999 additions and 24 deletions.
1,230 changes: 1,230 additions & 0 deletions alphartc_gym/tests/data/4G_3mbps.json

Large diffs are not rendered by default.

2,656 changes: 2,656 additions & 0 deletions alphartc_gym/tests/data/4G_500kbps.json

Large diffs are not rendered by default.

2,671 changes: 2,671 additions & 0 deletions alphartc_gym/tests/data/4G_700kbps.json

Large diffs are not rendered by default.

1,234 changes: 1,234 additions & 0 deletions alphartc_gym/tests/data/5G_12mbps.json

Large diffs are not rendered by default.

1,226 changes: 1,226 additions & 0 deletions alphartc_gym/tests/data/5G_13mbps.json

Large diffs are not rendered by default.

2,892 changes: 2,892 additions & 0 deletions alphartc_gym/tests/data/WIRED_200kbps.json

Large diffs are not rendered by default.

1,238 changes: 1,238 additions & 0 deletions alphartc_gym/tests/data/WIRED_35mbps.json

Large diffs are not rendered by default.

806 changes: 806 additions & 0 deletions alphartc_gym/tests/data/WIRED_900kbs.json

Large diffs are not rendered by default.

30 changes: 15 additions & 15 deletions alphartc_gym/tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from alphartc_gym import gym

import os
import glob


def test_basic():
total_stats = []
Expand Down Expand Up @@ -43,18 +45,16 @@ def test_multiple_instances():

def test_trace():
total_stats = []
trace_path = os.path.join(
os.path.dirname(__file__),
"data",
"trace_example.json")
g = gym.Gym("test_gym")
g.reset(trace_path=trace_path, report_interval_ms=60, duration_time_ms=0)
while True:
stats, done = g.step(1000)
if not done:
total_stats += stats
else:
break
assert(total_stats)
for stats in total_stats:
assert(isinstance(stats, dict))
trace_files = os.path.join(os.path.dirname(__file__), "data", "*.json")
for trace_file in glob.glob(trace_files):
g = gym.Gym("test_gym")
g.reset(trace_path=trace_file, report_interval_ms=60, duration_time_ms=0)
while True:
stats, done = g.step(1000)
if not done:
total_stats += stats
else:
break
assert(total_stats)
for stats in total_stats:
assert(isinstance(stats, dict))
36 changes: 27 additions & 9 deletions ns-app/scratch/webrtc_test/trace_player.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "trace_player.h"

#include "ns3/point-to-point-net-device.h"
#include "ns3/simulator.h"
#include "ns3/string.h"
#include "ns3/error-model.h"
Expand Down Expand Up @@ -41,16 +40,24 @@ void TracePlayer::LoadTrace() {
auto uplink_traces = j["uplink"]["trace_pattern"];
std::vector<TraceItem> traces;
traces.reserve(uplink_traces.size());

TraceItem last_available_value;
for (const auto &trace: uplink_traces) {
TraceItem ti;
ti.capacity_ = lexical_cast<decltype(ti.capacity_)>(trace["capacity"]);
ti.duration_ms_ = lexical_cast<decltype(ti.duration_ms_)>(trace["duration"]);
ti.capacity_ = lexical_cast<double>(trace["capacity"]);
ti.duration_ms_ = lexical_cast<double>(trace["duration"]);
if (trace.find("loss") != trace.end()) {
ti.loss_rate_ = lexical_cast<double>(trace["loss"]);
last_available_value.loss_rate_ = ti.loss_rate_;
} else {
ti.loss_rate_ = last_available_value.loss_rate_;
}
if (trace.find("rtt") != trace.end()) {
ti.rtt_ms_ = lexical_cast<std::uint64_t>(trace["rtt"]);
}
last_available_value.rtt_ms_ = ti.rtt_ms_;
} else {
ti.rtt_ms_ = last_available_value.rtt_ms_;
}
traces.push_back(std::move(ti));
}
traces_.swap(traces);
Expand All @@ -74,16 +81,27 @@ void TracePlayer::PlayTrace(size_t trace_index) {
auto device =
dynamic_cast<PointToPointNetDevice *>(PeekPointer(node->GetDevice(j)));
if (device) {
device->SetDataRate(DataRate(trace.capacity_ * 1e3));
// set loss rate in every device
if (trace.loss_rate_) {
Ptr<RateErrorModel> em = CreateObjectWithAttributes<RateErrorModel> ("RanVar", StringValue("ns3::UniformRandomVariable[Min=0.0|Max=1.0]"), \
"ErrorRate", DoubleValue (trace.loss_rate_.value()), \
"ErrorUnit", StringValue("ERROR_UNIT_PACKET"));
device->SetAttribute("ReceiveErrorModel", PointerValue (em));
SetLossRate(device, trace.loss_rate_.value());
}
if (trace.capacity_ == 0) {
SetLossRate(device, 1.0);
} else {
device->SetDataRate(DataRate(trace.capacity_ * 1e3));
if (!trace.loss_rate_) {
SetLossRate(device, 0.0);
}
}
}
}
}
Simulator::Schedule(MilliSeconds(trace.duration_ms_), &TracePlayer::PlayTrace, this, trace_index + 1);
}

void TracePlayer::SetLossRate(ns3::PointToPointNetDevice *device, double loss_rate) {
Ptr<RateErrorModel> em = CreateObjectWithAttributes<RateErrorModel> ("RanVar", StringValue("ns3::UniformRandomVariable[Min=0.0|Max=1.0]"), \
"ErrorRate", DoubleValue (loss_rate), \
"ErrorUnit", StringValue("ERROR_UNIT_PACKET"));
device->SetAttribute("ReceiveErrorModel", PointerValue (em));
}
4 changes: 4 additions & 0 deletions ns-app/scratch/webrtc_test/trace_player.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "ns3/point-to-point-net-device.h"
#include "ns3/node-container.h"

#include <string>
Expand All @@ -25,7 +26,10 @@ class TracePlayer {

void PlayTrace(size_t trace_index = 0);

void SetLossRate(ns3::PointToPointNetDevice *device, double loss_rate);

const std::string source_file_;
std::vector<TraceItem> traces_;
ns3::NodeContainer &nodes_;

};

0 comments on commit 3b5f161

Please sign in to comment.