using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.UIElements;
using System.Linq;
using Unity.Mathematics;
using Unity.VisualScripting;

/// <summary>
/// ML-Agents agent that drives a WheelCollider-based car towards a randomly placed
/// target. Actions are two discrete branches (throttle, steering); the agent receives
/// a reward of 1 and the episode ends once the car is within
/// <see cref="targetReachedDistance"/> of the target on the x/z plane.
/// </summary>
public class AgentController : Agent
{
    public float motorTorque = 300;
    public float brakeTorque = 500;
    public float maxSpeed = 400;
    public float steeringRange = 9;
    public float steeringRangeAtMaxSpeed = 7;
    public float autoBrake = 100;

    /// <summary>Planar (x/z) distance to the target at which it counts as reached.</summary>
    public float targetReachedDistance = 0.5f;

    /// <summary>
    /// |forward speed| below which the car is treated as stationary, so throttle in
    /// either direction drives it. Without this dead-zone, Mathf.Sign(0) == 1 means
    /// reverse input at rest matched neither "accelerate" nor "brake".
    /// </summary>
    public float stationarySpeedThreshold = 0.1f;

    WheelControl[] wheels;
    Rigidbody rigidBody;

    // Target is re-spawned in local space within x 35..39, y 0.25, z -30..-20.
    public Transform Target;

    Vector3 startPosition;
    Quaternion startRotation;

    /// <summary>
    /// One-time setup: caches the rigidbody and wheels and records the spawn pose.
    /// Done in Initialize (not Start) so it is guaranteed to run before OnEpisodeBegin.
    /// </summary>
    public override void Initialize()
    {
        rigidBody = GetComponent<Rigidbody>();
        // Find all child GameObjects that have the WheelControl script attached.
        wheels = GetComponentsInChildren<WheelControl>();
        startPosition = transform.localPosition;
        startRotation = transform.localRotation;
    }

    /// <summary>
    /// Resets wheel inputs, returns the car to its spawn pose at rest, and moves the
    /// target to a random position inside the spawn area.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        foreach (var wheel in wheels)
        {
            wheel.WheelCollider.brakeTorque = 0;
            wheel.WheelCollider.motorTorque = 0;
            wheel.WheelCollider.steerAngle = 0;
        }

        transform.localPosition = startPosition;
        transform.localRotation = startRotation;
        rigidBody.velocity = Vector3.zero;
        rigidBody.angularVelocity = Vector3.zero;

        Target.localPosition = new Vector3(
            UnityEngine.Random.Range(35f, 39f),
            0.25f,
            UnityEngine.Random.Range(-30f, -20f));
    }

    /// <summary>
    /// Adds 6 observations: target x/z, agent x/z (local space), signed forward speed,
    /// and angular-velocity magnitude.
    /// NOTE(review): the car's heading is not observed, so the policy cannot tell which
    /// way it faces; adding it would require changing the Behavior Parameters
    /// observation size.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(Target.localPosition.x);
        sensor.AddObservation(Target.localPosition.z);
        sensor.AddObservation(transform.localPosition.x);
        sensor.AddObservation(transform.localPosition.z);

        // Signed speed along the car's forward axis (negative while reversing).
        sensor.AddObservation(ForwardSpeed());
        sensor.AddObservation(rigidBody.angularVelocity.magnitude);
    }

    /// <summary>
    /// Manual reset for play-testing. Ending the episode (instead of teleporting the car
    /// mid-episode) keeps the recorded trajectory consistent; EndEpisode() invokes
    /// OnEpisodeBegin(), which performs the actual reset.
    /// </summary>
    void Update()
    {
        if (Input.GetKeyDown("space"))
        {
            EndEpisode();
        }
    }

    /// <summary>
    /// Applies the chosen throttle/steering to the wheels and rewards reaching the target.
    /// Each discrete branch takes values {0, 1, 2}, mapped to axis values {-1, +1, 0}:
    /// branch 0 is throttle (vertical), branch 1 is steering (horizontal).
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        float vInput = DiscreteToAxis(actions.DiscreteActions[0]);
        float hInput = DiscreteToAxis(actions.DiscreteActions[1]);

        ApplyDriving(vInput, hInput);

        if (PlanarDistanceToTarget() < targetReachedDistance)
        {
            SetReward(1.0f);
            EndEpisode();
        }
    }

    /// <summary>
    /// Keyboard control: Vertical/Horizontal axes beyond ±0.5 map to actions 0 (negative)
    /// and 1 (positive); otherwise 2 (no input).
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = AxisToDiscrete(Input.GetAxis("Vertical"));
        discreteActionsOut[1] = AxisToDiscrete(Input.GetAxis("Horizontal"));
    }

    /// <summary>Maps a discrete action {0, 1, other} to an axis value {-1, +1, 0}.</summary>
    static float DiscreteToAxis(int action) => action switch
    {
        0 => -1f,
        1 => 1f,
        _ => 0f,
    };

    /// <summary>Inverse of <see cref="DiscreteToAxis"/> for analogue input with a ±0.5 dead-zone.</summary>
    static int AxisToDiscrete(float axis)
    {
        if (axis < -0.5f) return 0;
        if (axis > 0.5f) return 1;
        return 2;
    }

    /// <summary>Velocity component along the car's forward axis (negative when reversing).</summary>
    float ForwardSpeed() => Vector3.Dot(transform.forward, rigidBody.velocity);

    /// <summary>
    /// x/z distance between the car and the target, matching the x/z-only observations.
    /// Assumes both transforms share the same parent space — TODO confirm in the scene.
    /// </summary>
    float PlanarDistanceToTarget()
    {
        Vector3 delta = Target.localPosition - transform.localPosition;
        return Mathf.Sqrt(delta.x * delta.x + delta.z * delta.z);
    }

    /// <summary>
    /// Sets steering, motor torque and brake torque on every wheel:
    ///  - no throttle: motor off, auto-brake (5x stronger while rolling backwards);
    ///  - throttle along the direction of travel, or car (nearly) stationary: drive,
    ///    brakes released;
    ///  - throttle against the direction of travel: motor off, full brake.
    /// </summary>
    void ApplyDriving(float vInput, float hInput)
    {
        float forwardSpeed = ForwardSpeed();
        float absSpeed = Mathf.Abs(forwardSpeed);

        // 0..1: how close the car is to the speed at which motor torque reaches zero
        // (maxSpeed / 4). Abs() makes the cap apply in reverse as well.
        float speedFactor = Mathf.InverseLerp(0, maxSpeed / 4, absSpeed);
        float currentMotorTorque = Mathf.Lerp(motorTorque, 0, speedFactor);
        // The car steers more gently at higher speed.
        float currentSteerRange = Mathf.Lerp(steeringRange, steeringRangeAtMaxSpeed, speedFactor);

        bool isCoasting = vInput == 0;
        bool isStationary = absSpeed < stationarySpeedThreshold;
        bool isAccelerating = !isCoasting
            && (isStationary || Mathf.Sign(vInput) == Mathf.Sign(forwardSpeed));

        foreach (var wheel in wheels)
        {
            if (wheel.steerable)
            {
                wheel.WheelCollider.steerAngle = hInput * currentSteerRange;
            }

            if (isCoasting)
            {
                wheel.WheelCollider.motorTorque = 0;
                wheel.WheelCollider.brakeTorque = forwardSpeed < 0 ? autoBrake * 5 : autoBrake;
            }
            else if (isAccelerating)
            {
                if (wheel.motorized)
                {
                    wheel.WheelCollider.motorTorque = vInput * currentMotorTorque;
                }
                wheel.WheelCollider.brakeTorque = 0;
            }
            else
            {
                // Input opposes the direction of travel: brake, and stop driving so the
                // motor does not fight the brakes.
                wheel.WheelCollider.motorTorque = 0;
                wheel.WheelCollider.brakeTorque = Mathf.Abs(vInput) * brakeTorque;
            }
        }
    }
}