ML-agentsでボール投げのAIを作っています。
最初はボールをくっつけてなんとかしてくれるのですが、2回目からjointが二つになってしまい、3回目以降からjointが4つになってしまうため、うまく動作ができません。
「なんとか複製されてしまうのを防げれば」光が見えるのですが、知識、ヒントをお持ちの方いないでしょうか。
今は手の3dラグドールの手の先端に、スフィアをくっつけて、-秒後に勝手に離れるので投げてくれ!という感じにしてます。
3回目に空中に固定されたボールに腕だけが接着し、最終的にラグドールがorzという形になって終わります。
至急答えてくれると幸いです。三回endepisodeが繰り返されてるのではないかという気がするのですが...
Walker agent.cs
using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;
using BodyPart = Unity.MLAgentsExamples.BodyPart;
using Random = UnityEngine.Random;
[DisallowMultipleComponent]
public class WalkerAgent : Agent
{
public bool randomizeThrowSpeedEachEpisode;
private Vector3 m_WorldDirToThrow = Vector3.right;
[Header("Target To Throw Towards")] public Transform target; [Header("Body Parts")] public Transform hips; public Transform chest; public Transform spine; public Transform head; public Transform thighL; public Transform shinL; public Transform footL; public Transform thighR; public Transform shinR; public Transform footR; public Transform armL; public Transform forearmL; public Transform handL; public Transform armR; public Transform forearmR; public Transform handR; OrientationCubeController m_OrientationCube; DirectionIndicator m_DirectionIndicator; JointDriveController m_JdController; EnvironmentParameters m_ResetParams; GameObject m_Sphere; GameObject m_target; GameObject hand_R; Vector3 t_Sphere; float byou; public override void Initialize() { m_OrientationCube = GetComponentInChildren<OrientationCubeController>(); m_DirectionIndicator = GetComponentInChildren<DirectionIndicator>(); m_Sphere = transform.Find("Sphere").gameObject; m_target = GameObject.Find("DynamicTarget"); hand_R = transform.Find("hips/spine/chest/upper_arm_R/lower_arm_R/hand_R").gameObject; t_Sphere = m_Sphere.transform.localPosition; m_JdController = GetComponent<JointDriveController>(); m_JdController.SetupBodyPart(hips); m_JdController.SetupBodyPart(chest); m_JdController.SetupBodyPart(spine); m_JdController.SetupBodyPart(head); m_JdController.SetupBodyPart(thighL); m_JdController.SetupBodyPart(shinL); m_JdController.SetupBodyPart(footL); m_JdController.SetupBodyPart(thighR); m_JdController.SetupBodyPart(shinR); m_JdController.SetupBodyPart(footR); m_JdController.SetupBodyPart(armL); m_JdController.SetupBodyPart(forearmL); m_JdController.SetupBodyPart(handL); m_JdController.SetupBodyPart(armR); m_JdController.SetupBodyPart(forearmR); m_JdController.SetupBodyPart(handR); m_ResetParams = Academy.Instance.EnvironmentParameters; SetResetParameters(); } public override void OnEpisodeBegin() { foreach (var bodyPart in m_JdController.bodyPartsDict.Values) { bodyPart.Reset(bodyPart); } byou = 0f; m_Sphere.transform.localPosition = t_Sphere; m_Sphere.gameObject.AddComponent<FixedJoint>(); m_Sphere.gameObject.GetComponent<FixedJoint>().connectedBody = hand_R.gameObject.GetComponent<Rigidbody>(); UpdateOrientationObjects(); SetResetParameters(); } public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor) { sensor.AddObservation(bp.groundContact.touchingGround); sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.velocity)); sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.velocity)); sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity)); sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.position - hips.position)); if (bp.rb.transform != hips && bp.rb.transform != handL && bp.rb.transform != handR) { sensor.AddObservation(bp.rb.transform.localRotation); sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit); } } public override void CollectObservations(VectorSensor sensor) { var cubeForward = m_OrientationCube.transform.forward; //rotation deltas sensor.AddObservation(Quaternion.FromToRotation(hips.forward, cubeForward)); sensor.AddObservation(Quaternion.FromToRotation(head.forward, cubeForward)); sensor.AddObservation(m_OrientationCube.transform.InverseTransformPoint(target.transform.position)); foreach (var bodyPart in m_JdController.bodyPartsList) { CollectObservationBodyPart(bodyPart, sensor); } } public override void OnActionReceived(ActionBuffers actionBuffers) { var bpDict = m_JdController.bodyPartsDict; var i = -1; var continuousActions = actionBuffers.ContinuousActions; bpDict[chest].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]); bpDict[spine].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]); bpDict[thighL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0); bpDict[thighR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0); bpDict[shinL].SetJointTargetRotation(continuousActions[++i], 0, 0); bpDict[shinR].SetJointTargetRotation(continuousActions[++i], 0, 0); bpDict[footR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]); bpDict[footL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]); bpDict[armL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0); bpDict[armR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0); bpDict[forearmL].SetJointTargetRotation(continuousActions[++i], 0, 0); bpDict[forearmR].SetJointTargetRotation(continuousActions[++i], 0, 0); bpDict[head].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0); //update joint strength settings bpDict[chest].SetJointStrength(continuousActions[++i]); bpDict[spine].SetJointStrength(continuousActions[++i]); bpDict[head].SetJointStrength(continuousActions[++i]); bpDict[thighL].SetJointStrength(continuousActions[++i]); bpDict[shinL].SetJointStrength(continuousActions[++i]); bpDict[footL].SetJointStrength(continuousActions[++i]); bpDict[thighR].SetJointStrength(continuousActions[++i]); bpDict[shinR].SetJointStrength(continuousActions[++i]); bpDict[footR].SetJointStrength(continuousActions[++i]); bpDict[armL].SetJointStrength(continuousActions[++i]); bpDict[forearmL].SetJointStrength(continuousActions[++i]); bpDict[armR].SetJointStrength(continuousActions[++i]); bpDict[forearmR].SetJointStrength(continuousActions[++i]); } void UpdateOrientationObjects() { m_WorldDirToThrow = target.position - hips.position; m_OrientationCube.UpdateOrientation(hips, target); if (m_DirectionIndicator) { m_DirectionIndicator.MatchOrientation(m_OrientationCube.transform); } } void FixedUpdate() { UpdateOrientationObjects(); var cubeForward = m_OrientationCube.transform.forward; var lookAtTargetReward = (Vector3.Dot(cubeForward, head.forward) + 1) * .5F; if (float.IsNaN(lookAtTargetReward)) { throw new ArgumentException( "NaN in lookAtTargetReward.\n" + $" cubeForward: {cubeForward}\n" + $" head.forward: {head.forward}" ); } AddReward(lookAtTargetReward); byou += Time.deltaTime; Destroy(m_Sphere.gameObject.GetComponent<FixedJoint>(),1.0f); if(byou >= 4){ if(m_Sphere.GetComponent<Rigidbody>().velocity==Vector3.zero){ Vector3 posA = m_Sphere.transform.position; Vector3 posB = m_target.transform.position; float dis = Vector3.Distance(posA,posB); Vector3 posC = m_OrientationCube.transform.position; float kyori = Vector3.Distance(posC,posB); float res = kyori-dis; AddReward(res*0.01f); EndEpisode(); } } } public void SetTorsoMass() { m_JdController.bodyPartsDict[chest].rb.mass = m_ResetParams.GetWithDefault("chest_mass", 8); m_JdController.bodyPartsDict[spine].rb.mass = m_ResetParams.GetWithDefault("spine_mass", 8); m_JdController.bodyPartsDict[hips].rb.mass = m_ResetParams.GetWithDefault("hip_mass", 8); } public void SetResetParameters() { SetTorsoMass(); }
}
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。