using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.UIElements;
using System.Linq;
using Unity.Mathematics;
using Unity.VisualScripting;
using System.Reflection;

/// <summary>
/// ML-Agents controller for a wheeled car. Each step the agent picks a throttle/brake
/// action and a steering action (two discrete branches). It is rewarded for reaching
/// the checkpoints in <see cref="checkpoints"/> in list order.
/// </summary>
public class AgentController : Agent
{
    public float motorTorque = 300;
    public float brakeTorque = 500;
    public float maxSpeed = 400;
    public float steeringRange = 9;
    public float steeringRangeAtMaxSpeed = 7;
    public float autoBrake = 100;

    WheelControl[] wheels;
    Rigidbody rigidBody;

    // Checkpoints in the order they must be reached; reaching the last one ends the episode.
    public List<GameObject> checkpoints;

    Vector3 startPosition;
    Quaternion startRotation;

    int currentStep = 0;
    int stepsSinceCheckpoint = 0;

    // The episode ends if this many actions pass without reaching the next checkpoint.
    public int maxStepsPerCheckpoint = 300;

    // World-space distance from a checkpoint collider's bounds at which it counts as reached.
    const float CheckpointReachDistance = 0.3f;

    // Start is called before the first frame update.
    // NOTE(review): [Obsolete] is not needed for Unity to call Start; presumably it was
    // added to silence an obsolete-API warning — consider removing it.
    [System.Obsolete]
    void Start()
    {
        rigidBody = GetComponent<Rigidbody>();
        // Find all child GameObjects that have the WheelControl script attached
        wheels = GetComponentsInChildren<WheelControl>();
        // Remember the spawn pose so every episode restarts from the same place.
        startPosition = transform.localPosition;
        startRotation = transform.localRotation;
    }

    /// <summary>
    /// Resets the per-checkpoint step counter, the wheel colliders, the car's pose and
    /// velocity, and every checkpoint's collected flag.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        // Assign the field (not a new local) so the step budget restarts each episode.
        stepsSinceCheckpoint = 0;

        // reset wheels
        foreach (var wheel in wheels)
        {
            wheel.WheelCollider.brakeTorque = 0;
            wheel.WheelCollider.motorTorque = 0;
            wheel.WheelCollider.steerAngle = 0;
        }

        // reset car
        transform.localPosition = startPosition;
        transform.localRotation = startRotation;
        rigidBody.velocity = Vector3.zero;
        rigidBody.angularVelocity = Vector3.zero;

        // reset checkpoints
        foreach (GameObject checkpoint in checkpoints)
        {
            SetCheckpointCollected(checkpoint, false);
        }
    }

    /// <summary>
    /// Adds 6 observations: target checkpoint local x/z, agent local x/z,
    /// agent world yaw (degrees), and agent linear speed.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        Transform currentCheckpoint = GetCurrentCheckpoint();

        // target checkpoint
        sensor.AddObservation(currentCheckpoint.localPosition.x);
        sensor.AddObservation(currentCheckpoint.localPosition.z);

        // agent position and heading
        sensor.AddObservation(transform.localPosition.x);
        sensor.AddObservation(transform.localPosition.z);
        sensor.AddObservation(transform.rotation.eulerAngles.y);

        // agent speed: magnitude of the rigidbody's linear velocity (angular velocity is not included)
        sensor.AddObservation(rigidBody.velocity.magnitude);
    }

    /// <summary>
    /// Applies the chosen action to the wheels, then updates checkpoint rewards and
    /// episode termination.
    /// Branch 0 = throttle, branch 1 = steering; each is {0, 1, 2} -> {-1, +1, 0}.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        float vInput = DiscreteToAxis(actions.DiscreteActions[0]);
        float hInput = DiscreteToAxis(actions.DiscreteActions[1]);

        ApplyDriving(vInput, hInput);
        UpdateCheckpointProgress();
    }

    /// <summary>Keyboard control: Vertical/Horizontal axes mapped onto the two discrete branches.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = AxisToDiscrete(Input.GetAxis("Vertical"));
        discreteActionsOut[1] = AxisToDiscrete(Input.GetAxis("Horizontal"));
    }

    // Maps a discrete branch value to an axis value: 0 -> -1, 1 -> +1, anything else -> 0.
    static float DiscreteToAxis(int action)
    {
        switch (action)
        {
            case 0: return -1f;
            case 1: return 1f;
            default: return 0f;
        }
    }

    // Inverse of DiscreteToAxis with a +/-0.5 dead zone: < -0.5 -> 0, > 0.5 -> 1, otherwise 2 (no-op).
    static int AxisToDiscrete(float axis)
    {
        if (axis < -0.5f) return 0;
        if (axis > 0.5f) return 1;
        return 2;
    }

    // Sets steering, motor and brake torque on every wheel from the throttle (vInput)
    // and steering (hInput) values, each in {-1, 0, 1}.
    void ApplyDriving(float vInput, float hInput)
    {
        float forwardSpeed = Vector3.Dot(transform.forward, rigidBody.velocity);

        // How close the car is to its (quarter of maxSpeed) top speed, from 0 to 1.
        float speedFactor = Mathf.InverseLerp(0, maxSpeed / 4, forwardSpeed);
        // Less motor torque and gentler steering as the car approaches top speed.
        float currentMotorTorque = Mathf.Lerp(motorTorque, 0, speedFactor);
        float currentSteerRange = Mathf.Lerp(steeringRange, steeringRangeAtMaxSpeed, speedFactor);

        // Mathf.Sign returns +1 for zero, so with no throttle input this is true whenever the
        // car is not rolling backwards (the motor torque written below is then 0).
        bool isAccelerating = Mathf.Sign(vInput) == Mathf.Sign(forwardSpeed);
        bool isStopping = vInput == 0;
        // Throttle input opposes the current direction of travel.
        bool isBraking = (vInput < 0 && forwardSpeed > 0) || (vInput > 0 && forwardSpeed < 0);

        foreach (var wheel in wheels)
        {
            // Apply steering to Wheel colliders that have "Steerable" enabled
            if (wheel.steerable)
            {
                wheel.WheelCollider.steerAngle = hInput * currentSteerRange;
            }

            if (isBraking)
            {
                // Motor torque is left unchanged while braking.
                wheel.WheelCollider.brakeTorque = Mathf.Abs(vInput) * brakeTorque;
            }

            if (isAccelerating)
            {
                // Apply torque to Wheel colliders that have "Motorized" enabled
                if (wheel.motorized)
                {
                    wheel.WheelCollider.motorTorque = vInput * currentMotorTorque;
                }
                wheel.WheelCollider.brakeTorque = 0;
            }

            if (isStopping)
            {
                // No throttle input: apply the idle auto-brake, five times stronger when
                // rolling backwards. Runs after the accelerating branch so it wins over its brake reset.
                wheel.WheelCollider.brakeTorque = forwardSpeed < 0 ? autoBrake * 5 : autoBrake;
            }
        }
    }

    // Rewards reaching the current checkpoint and ends the episode on the last checkpoint
    // or when the per-checkpoint step budget runs out.
    void UpdateCheckpointProgress()
    {
        currentStep += 1;

        Transform currentCheckpoint = GetCurrentCheckpoint();
        // ClosestPointOnBounds takes and returns world-space positions.
        Vector3 closestPoint = currentCheckpoint.GetComponent<Collider>().ClosestPointOnBounds(transform.position);
        float distanceToCheckpoint = Vector3.Distance(transform.position, closestPoint);

        if (distanceToCheckpoint < CheckpointReachDistance)
        {
            Debug.Log(currentCheckpoint.name);
            SetCheckpointCollected(currentCheckpoint.gameObject, true);
            stepsSinceCheckpoint = 0;

            if (currentCheckpoint == checkpoints[checkpoints.Count - 1].transform)
            {
                SetReward(10f);
                Debug.Log("END");
                // EndEpisode resets the agent (running OnEpisodeBegin) right away; return so
                // no reward or counter update from this step leaks into the new episode.
                EndEpisode();
                return;
            }

            SetReward(1.0f);
        }

        stepsSinceCheckpoint += 1;
        if (stepsSinceCheckpoint >= maxStepsPerCheckpoint)
        {
            // Too slow to reach the next checkpoint; OnEpisodeBegin resets the counter.
            EndEpisode();
        }
    }

    // Returns the first checkpoint not yet collected, or the first checkpoint if all are collected.
    Transform GetCurrentCheckpoint()
    {
        foreach (GameObject checkpoint in checkpoints)
        {
            if (!IsCheckpointCollected(checkpoint))
            {
                return checkpoint.transform;
            }
        }
        return checkpoints[0].transform;
    }

    // NOTE(review): the type argument of GetComponent (the checkpoint script exposing
    // `isCollected`) is missing in this copy of the file; restore it here and in
    // SetCheckpointCollected, e.g. GetComponent<YourCheckpointScript>().
    bool IsCheckpointCollected(GameObject checkpoint)
    {
        return checkpoint.GetComponent().isCollected;
    }

    void SetCheckpointCollected(GameObject checkpoint, bool collected)
    {
        checkpoint.GetComponent().isCollected = collected;
    }
}