Skip to content

Commit

Permalink
Add another way of estimating the remaining energy
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed May 25, 2023
1 parent 39739d0 commit 82a700a
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 7 deletions.
11 changes: 7 additions & 4 deletions Assets/Scripts/Agents/AgentBasic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ public class AgentBasic : Agent, IAgent
// protected int Unfrozen = 1;

internal int Collision = 0;




public Vector3 startPosition;

private float _originalHeight;
private float _originalGoalHeight;
Expand Down Expand Up @@ -212,6 +211,8 @@ public override void OnEpisodeBegin()
PreviousPosition = transform.localPosition;
PreviousVelocity = Vector3.zero;

startPosition = transform.localPosition;

PreviousPositionPhysics = transform.localPosition;
PreviouserPositionPhysics = transform.localPosition;
// PreviousVelocityPhysics = Vector3.zero;
Expand Down Expand Up @@ -246,7 +247,9 @@ public override void OnEpisodeBegin()
["r_speedmatch"] = 0f,
["r_speeding"] = 0f,
["r_velocity"] = 0f,
["r_expVelocity"] = 0f
["r_expVelocity"] = 0f,
["r_final"] = 0f,
["r_avgFinal"] = 0f,
};

UpdateParams();
Expand Down
17 changes: 17 additions & 0 deletions Assets/Scripts/MLUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using Managers;
using UnityEngine;
using Random = UnityEngine.Random;

Expand Down Expand Up @@ -249,4 +250,20 @@ public static float EnergyHeuristic(Vector3 position, Vector3 target, float e_s,

// return e_s * time + e_w * speed * speed * time;
}

/// <summary>
/// Estimates the energy still needed to reach the goal by assuming the agent
/// keeps moving at the average net speed it has sustained so far this episode.
/// </summary>
/// <param name="position">Agent's current local position.</param>
/// <param name="target">Goal's local position.</param>
/// <param name="startPosition">Agent's local position at episode start.</param>
/// <param name="e_s">Static (per-second) energy cost coefficient.</param>
/// <param name="e_w">Kinetic (per speed²-second) energy cost coefficient.</param>
/// <returns>Estimated remaining energy expenditure (non-negative, finite).</returns>
public static float AverageEnergyHeuristic(Vector3 position, Vector3 target, Vector3 startPosition, float e_s, float e_w)
{
    var finalDistance = FlatDistance(position, target); // d' - distance still to cover
    var totalDistance = FlatDistance(startPosition, target); // d - distance at episode start
    var timeLimit = Manager.Instance.maxStep * Manager.Instance.DecisionDeltaTime; // T0 - episode time budget

    var avgSpeed = (totalDistance - finalDistance) / timeLimit; // v' - net progress speed so far

    // Guard: if the agent made no net progress (or moved away from the goal),
    // avgSpeed <= 0 and finalDistance / avgSpeed would be Infinity, NaN, or a
    // negative "remaining time". Fall back to the optimal-speed estimate
    // e(v*) = 2*sqrt(e_s*e_w)*d' (energy at the speed v* = sqrt(e_s/e_w) that
    // minimizes e_s*t + e_w*v^2*t over the remaining distance).
    // The negated comparison also catches NaN (e.g. when timeLimit is 0).
    if (!(avgSpeed > 0f))
    {
        return 2f * (float)System.Math.Sqrt(e_s * e_w) * finalDistance;
    }

    var remainingTime = finalDistance / avgSpeed; // T' - time to finish at v'

    // Remaining energy = static cost over T' plus kinetic cost at v' over T'.
    return e_s * remainingTime + e_w * avgSpeed * avgSpeed * remainingTime;
}
}
16 changes: 14 additions & 2 deletions Assets/Scripts/Managers/Manager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,8 @@ private Dictionary<string, float> GetEpisodeStats()
var energiesComplex = new List<float>();
var energiesPlus = new List<float>();
var energiesComplexPlus = new List<float>();
var energiesPlusAvg = new List<float>();
var energiesComplexPlusAvg = new List<float>();
var distances = new List<float>();
var successes = new List<float>();
var numAgents = 0;
Expand All @@ -421,23 +423,33 @@ private Dictionary<string, float> GetEpisodeStats()

// var finalEnergy = 2 * Mathf.Sqrt(agent.e_s * agent.e_w * finalDistance);

var finalEnergy = MLUtils.EnergyHeuristic(agent.transform.localPosition, agent.Goal.localPosition,
var localPosition = agent.transform.localPosition;
var goalPosition = agent.Goal.localPosition;
var finalEnergy = MLUtils.EnergyHeuristic(localPosition, goalPosition,
agent.e_s, agent.e_w);

var finalEnergyAvg = MLUtils.AverageEnergyHeuristic(localPosition, goalPosition, agent.startPosition,
agent.e_s, agent.e_w);

energiesPlus.Add(agent.energySpent + finalEnergy);
energiesComplexPlus.Add(agent.energySpentComplex + finalEnergy);

energiesPlusAvg.Add(agent.energySpent + finalEnergyAvg);
energiesComplexPlusAvg.Add(agent.energySpentComplex + finalEnergyAvg);

distances.Add(agent.distanceTraversed);
successes.Add(agent.CollectedGoal ? 1f : 0f);
numAgents++;
}
Debug.Log($"NumAgents detected in EpisodeStats: {numAgents}");
// Debug.Log($"NumAgents detected in EpisodeStats: {numAgents}");
var stats = new Dictionary<string, float>
{
["e_energy"] = energies.Average(),
["e_energy_complex"] = energiesComplex.Average(),
["e_energy_plus"] = energiesPlus.Average(),
["e_energy_complex_plus"] = energiesComplexPlus.Average(),
["e_energy_plus_avg"] = energiesPlusAvg.Average(),
["e_energy_complex_plus_avg"] = energiesComplexPlusAvg.Average(),
["e_distance"] = distances.Average(),
["e_success"] = successes.Average(),
};
Expand Down
3 changes: 3 additions & 0 deletions Assets/Scripts/Params.cs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ private void Awake()
// Weight of the one-off final-energy penalty applied when the episode ends.
public float rewardFinal = 1f;
// Overridable at runtime via the "r_final" parameter; falls back to the inspector value.
public static float RewFinal => Get("r_final", Instance.rewardFinal);

// Weight of the average-speed-based final-energy penalty (AverageEnergyHeuristic variant).
public float rewardAvgFinal = 1f;
// Overridable at runtime via the "r_avg_final" parameter.
// NOTE(review): this parameter key is snake_case ("r_avg_final") while the reward-part
// tag used elsewhere in this commit is camelCase ("r_avgFinal") - confirm the
// asymmetry is intentional, since the two strings live in different lookups.
public static float RewAvgFinal => Get("r_avg_final", Instance.rewardAvgFinal);




Expand Down
8 changes: 7 additions & 1 deletion Assets/Scripts/Rewards/DecisionRewarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,14 @@ public float FinishReward(Transform transform, bool success)
// var finalReward = -2 * Mathf.Sqrt(agent.e_s * agent.e_w) * finalDistance;

reward += Params.RewFinal * penalty;
agent.AddRewardPart(penalty, "final");
agent.AddRewardPart(penalty, "r_final");


var avgPenalty = -MLUtils.AverageEnergyHeuristic(transform.localPosition, agent.Goal.localPosition, agent.startPosition, agent.e_s, agent.e_w);


reward += Params.RewAvgFinal * avgPenalty;
agent.AddRewardPart(avgPenalty, "r_avgFinal");
// TODO: Instead of assuming the optimal velocity, use the average velocity across the trajectory so far
// TODO: Track both of them as a metric, but add a switch to choose which one to use for the reward

Expand Down

0 comments on commit 82a700a

Please sign in to comment.