made environment for BensonV4
This commit is contained in:
@@ -115,6 +115,9 @@ public class AgentController : Agent
|
||||
|
||||
public override void OnActionReceived(ActionBuffers actions)
|
||||
{
|
||||
|
||||
print("L");
|
||||
|
||||
// Actions size = 2 [vertical speed, horizontal speed] = [-1..1, -1..1] // discrete = [{0, 1, 2}, {0, 1, 2}] = [{-1, 0, 1}...]
|
||||
float vInput = 0;
|
||||
float hInput = 0;
|
||||
@@ -133,7 +136,7 @@ public class AgentController : Agent
|
||||
|
||||
if (vInput == 1f)
|
||||
{
|
||||
AddReward(0.001f);
|
||||
AddReward(0.02f);
|
||||
}
|
||||
|
||||
// give benson mental pain for existing (punishment for maximizing first checkpoint by standing still)
|
||||
226
Assets/Scripts/AgentControllerV4.cs
Normal file
226
Assets/Scripts/AgentControllerV4.cs
Normal file
@@ -0,0 +1,226 @@
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using UnityEngine;
|
||||
using Unity.MLAgents;
|
||||
using Unity.MLAgents.Sensors;
|
||||
using Unity.MLAgents.Actuators;
|
||||
using UnityEngine.UIElements;
|
||||
using System.Linq;
|
||||
using Unity.Mathematics;
|
||||
using Unity.VisualScripting;
|
||||
using System.Reflection;
|
||||
using System;
|
||||
|
||||
public class AgentControllerV4 : Agent
|
||||
{
|
||||
public float motorForce = 300;
|
||||
public float brakeForce = 500;
|
||||
public float maxSpeed = 400;
|
||||
public float steeringRange = 9;
|
||||
Rigidbody rigidBody;
|
||||
public List<GameObject> checkpoints;
|
||||
Vector3 startPosition;
|
||||
Quaternion startRotation;
|
||||
int currentStep = 0;
|
||||
int stepsSinceCheckpoint = 0;
|
||||
public int maxStepsPerCheckpoint = 300;
|
||||
public int distanceBetweenCheckpoints;
|
||||
|
||||
|
||||
// Start is called before the first frame update
|
||||
void Start()
|
||||
{
|
||||
rigidBody = GetComponent<Rigidbody>();
|
||||
// Find all child GameObjects that have the WheelControl script attached
|
||||
startPosition = transform.localPosition;
|
||||
startRotation = transform.localRotation;
|
||||
}
|
||||
|
||||
public override void OnEpisodeBegin()
|
||||
{
|
||||
stepsSinceCheckpoint = 0;
|
||||
|
||||
// reset car
|
||||
transform.localPosition = startPosition;
|
||||
transform.localRotation = startRotation;
|
||||
rigidBody.velocity = Vector3.zero;
|
||||
rigidBody.angularVelocity = Vector3.zero;
|
||||
|
||||
// reset checkpoints
|
||||
foreach (GameObject checkpoint in checkpoints)
|
||||
{
|
||||
checkpoint.GetComponent<Checkpoint>().isCollected = false;
|
||||
}
|
||||
}
|
||||
|
||||
public override void CollectObservations(VectorSensor sensor)
|
||||
{
|
||||
Transform currentCheckpoint = checkpoints[0].transform;
|
||||
foreach (GameObject checkpoint in checkpoints)
|
||||
{
|
||||
bool isCollected = checkpoint.GetComponent<Checkpoint>().isCollected;
|
||||
|
||||
if (!isCollected)
|
||||
{
|
||||
currentCheckpoint = checkpoint.transform;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
print(transform.rotation.y);
|
||||
sensor.AddObservation(transform.rotation.y);
|
||||
|
||||
Vector3 position = transform.localPosition;
|
||||
Vector3 checkpointPosition = currentCheckpoint.localPosition;
|
||||
|
||||
Vector2 toCheckpoint = new Vector2(
|
||||
checkpointPosition.x - position.x,
|
||||
checkpointPosition.z - position.z
|
||||
);
|
||||
|
||||
sensor.AddObservation(toCheckpoint.normalized);
|
||||
sensor.AddObservation(toCheckpoint.magnitude);
|
||||
|
||||
var FullVelocityMagnitude = rigidBody.velocity.magnitude; // Velocity including angular velocity
|
||||
var angularMagnitude = rigidBody.angularVelocity.magnitude;
|
||||
|
||||
var forwardMagnitude = Mathf.Sqrt( Mathf.Pow(FullVelocityMagnitude, 2) - Mathf.Pow(angularMagnitude, 2)); // Agent velocity in forward direction
|
||||
|
||||
if (forwardMagnitude >= 0.001)
|
||||
sensor.AddObservation(forwardMagnitude);
|
||||
else
|
||||
sensor.AddObservation(FullVelocityMagnitude);
|
||||
|
||||
sensor.AddObservation(angularMagnitude);
|
||||
|
||||
}
|
||||
|
||||
public override void OnActionReceived(ActionBuffers actions)
|
||||
{
|
||||
// Actions size = 2 [vertical speed, horizontal speed] = [-1..1, -1..1] // discrete = [{0, 1, 2}, {0, 1, 2}] = [{-1, 0, 1}...]
|
||||
float vInput = 0;
|
||||
float hInput = 0;
|
||||
|
||||
if (actions.DiscreteActions[0] == 0)
|
||||
vInput = -1f;
|
||||
if (actions.DiscreteActions[0] == 1)
|
||||
vInput = 1f;
|
||||
|
||||
if (actions.DiscreteActions[1] == 0)
|
||||
hInput = -1f;
|
||||
if (actions.DiscreteActions[1] == 1)
|
||||
hInput = 1f;
|
||||
|
||||
// reward for going forward
|
||||
|
||||
// if (vInput == 1f)
|
||||
// {
|
||||
// AddReward(0.02f);
|
||||
// }
|
||||
|
||||
// give benson mental pain for existing (punishment for maximizing first checkpoint by standing still)
|
||||
AddReward(-0.002f);
|
||||
|
||||
Vector3 movementForce = vInput * motorForce * transform.forward;
|
||||
float carAngle = transform.rotation.eulerAngles.y + steeringRange * hInput;
|
||||
|
||||
float x = transform.rotation.eulerAngles.x;
|
||||
float z = transform.rotation.eulerAngles.z;
|
||||
|
||||
transform.rotation = Quaternion.Euler(x, carAngle, z);
|
||||
|
||||
rigidBody.AddForce(movementForce, ForceMode.Impulse);
|
||||
|
||||
// rewards
|
||||
Transform currentCheckpoint = checkpoints[0].transform;
|
||||
foreach (GameObject checkpoint in checkpoints)
|
||||
{
|
||||
bool isCollected = checkpoint.GetComponent<Checkpoint>().isCollected;
|
||||
|
||||
if (!isCollected)
|
||||
{
|
||||
currentCheckpoint = checkpoint.transform;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
float checkpintDistance = distanceToCheckpoint(currentCheckpoint);
|
||||
|
||||
float reward = (1 - Mathf.InverseLerp(0, distanceBetweenCheckpoints, checkpintDistance)) / 1000;
|
||||
|
||||
AddReward(reward);
|
||||
|
||||
if (checkpintDistance < 0.1f)
|
||||
{
|
||||
currentCheckpoint.GetComponent<Checkpoint>().isCollected = true;
|
||||
stepsSinceCheckpoint = 0;
|
||||
|
||||
if (currentCheckpoint == checkpoints[checkpoints.Count - 1].transform)
|
||||
{
|
||||
AddReward(10f);
|
||||
EndEpisode();
|
||||
}
|
||||
AddReward(1.0f);
|
||||
}
|
||||
|
||||
currentStep += 1;
|
||||
stepsSinceCheckpoint += 1;
|
||||
|
||||
if (stepsSinceCheckpoint >= maxStepsPerCheckpoint)
|
||||
{
|
||||
stepsSinceCheckpoint = 0;
|
||||
EndEpisode();
|
||||
}
|
||||
}
|
||||
|
||||
public override void Heuristic(in ActionBuffers actionsOut)
|
||||
{
|
||||
var discreteActionsOut = actionsOut.DiscreteActions;
|
||||
|
||||
discreteActionsOut[0] = 2;
|
||||
discreteActionsOut[1] = 2;
|
||||
|
||||
if (Input.GetAxis("Vertical") < -0.5)
|
||||
discreteActionsOut[0] = 0;
|
||||
if (Input.GetAxis("Vertical") > 0.5)
|
||||
discreteActionsOut[0] = 1;
|
||||
|
||||
if (Input.GetAxis("Horizontal") < -0.5)
|
||||
discreteActionsOut[1] = 0;
|
||||
if (Input.GetAxis("Horizontal") > 0.5)
|
||||
discreteActionsOut[1] = 1;
|
||||
}
|
||||
|
||||
// finds distance from agent to closest point on the checkpoint line
|
||||
float distanceToCheckpoint(Transform checkpoint)
|
||||
{
|
||||
var closestPoint = checkpoint.GetComponent<Collider>().ClosestPointOnBounds(transform.position);
|
||||
var distanceToCheckpoint = Vector3.Distance(transform.position, closestPoint);
|
||||
return distanceToCheckpoint;
|
||||
}
|
||||
|
||||
// find angle from agent to middle of checkpoint line.
|
||||
float angleToCheckpoint(Transform checkpoint)
|
||||
{
|
||||
Vector3 checkpointDirection = checkpoint.localPosition - transform.localPosition;
|
||||
|
||||
float angle = Vector3.SignedAngle(transform.forward, checkpointDirection, Vector3.up);
|
||||
return angle;
|
||||
}
|
||||
|
||||
// punishment for hitting a wall
|
||||
private void OnCollisionEnter(Collision other) {
|
||||
if (other.gameObject.tag == "Wall")
|
||||
{
|
||||
AddReward(-0.05f);
|
||||
}
|
||||
}
|
||||
|
||||
// punishment for staying at a wall
|
||||
private void OnCollisionStay(Collision other) {
|
||||
if (other.gameObject.tag == "Wall")
|
||||
{
|
||||
AddReward(-0.05f);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
fileFormatVersion: 2
|
||||
guid: 8c9672ffa7e21bf41b27f35da94be659
|
||||
guid: f5cda3de98f45464999f00bdd795f2a0
|
||||
MonoImporter:
|
||||
externalObjects: {}
|
||||
serializedVersion: 2
|
||||
@@ -1,44 +0,0 @@
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using UnityEngine;
|
||||
|
||||
public class CarControl : MonoBehaviour
|
||||
{
|
||||
|
||||
public float motorTorque = 2000;
|
||||
public float maxSpeed = 20;
|
||||
public float steeringRange = 30;
|
||||
//public float steeringRangeAtMaxSpeed = 10;
|
||||
WheelControl[] wheels;
|
||||
Rigidbody rigidBody;
|
||||
// Start is called before the first frame update
|
||||
void Start()
|
||||
{
|
||||
rigidBody = GetComponent<Rigidbody>();
|
||||
|
||||
// Find all child GameObjects that have the WheelControl script attached
|
||||
wheels = GetComponentsInChildren<WheelControl>();
|
||||
}
|
||||
|
||||
// Update is called once per frame
|
||||
void FixedUpdate()
|
||||
{
|
||||
float vInput = Input.GetAxis("Vertical");
|
||||
float hInput = Input.GetAxis("Horizontal");
|
||||
|
||||
foreach (var wheel in wheels)
|
||||
{
|
||||
// Apply steering to Wheel colliders that have "Steerable" enabled
|
||||
if (wheel.steerable)
|
||||
{
|
||||
wheel.WheelCollider.steerAngle = hInput * steeringRange;
|
||||
}
|
||||
|
||||
// Apply torque to Wheel colliders that have "Motorized" enabled
|
||||
if (wheel.motorized)
|
||||
{
|
||||
wheel.WheelCollider.motorTorque = vInput * motorTorque;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
fileFormatVersion: 2
|
||||
guid: a23c7d66d3ff94847a946cf7b30ca1b7
|
||||
MonoImporter:
|
||||
externalObjects: {}
|
||||
serializedVersion: 2
|
||||
defaultReferences: []
|
||||
executionOrder: 0
|
||||
icon: {instanceID: 0}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
@@ -5,4 +5,20 @@ using UnityEngine;
|
||||
public class Checkpoint : MonoBehaviour
|
||||
{
|
||||
public bool isCollected = false;
|
||||
MeshRenderer meshRenderer;
|
||||
|
||||
private void Start() {
|
||||
meshRenderer = GetComponent<MeshRenderer>();
|
||||
}
|
||||
|
||||
private void Update() {
|
||||
if (isCollected)
|
||||
{
|
||||
meshRenderer.enabled = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
meshRenderer.enabled = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using UnityEngine;
|
||||
|
||||
public class VehichleControl : MonoBehaviour
|
||||
{
|
||||
public InputController InputCtrl;
|
||||
[Tooltip("Set ref in order of FL, FR, RL, RR")]
|
||||
public WheelCollider[] WheelColliders;
|
||||
|
||||
[Tooltip("Set ref of wheel meshes in order of FL, FR, RL, RR")]
|
||||
public Transform[] Wheels;
|
||||
|
||||
public Transform CenterOfMass;
|
||||
|
||||
public int Force;
|
||||
public int Angle;
|
||||
public int BrakeForce;
|
||||
|
||||
private void Drive()
|
||||
{
|
||||
WheelColliders[0].motorTorque = WheelColliders[1].motorTorque = InputCtrl.Vertical * Force;
|
||||
}
|
||||
|
||||
private void Steer()
|
||||
{
|
||||
WheelColliders[0].steerAngle = WheelColliders[1].steerAngle = InputCtrl.Horizontal * Angle;
|
||||
}
|
||||
|
||||
private void Brake()
|
||||
{
|
||||
WheelColliders[0].brakeTorque = WheelColliders[1].brakeTorque = InputCtrl.Brake * BrakeForce;
|
||||
}
|
||||
|
||||
private void UpdateWheelMovements()
|
||||
{
|
||||
for (var i = 0; i < Wheels.Length; i++)
|
||||
{
|
||||
Vector3 pos;
|
||||
Quaternion rot;
|
||||
WheelColliders[i].GetWorldPose(out pos, out rot);
|
||||
Wheels[i].transform.position = pos;
|
||||
Wheels[i].transform.rotation = rot;
|
||||
}
|
||||
}
|
||||
|
||||
private void FixedUpdate()
|
||||
{
|
||||
Steer();
|
||||
Drive();
|
||||
Brake();
|
||||
UpdateWheelMovements();
|
||||
}
|
||||
|
||||
private void Start()
|
||||
{
|
||||
GetComponent<Rigidbody>().centerOfMass = CenterOfMass.localPosition;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user