実現したいこと
複数の3次元点群を、それぞれある位置中心にモデルのパラメータであるクォータニオンによって回転させ、返すモデルを作成しました。そして、このモデルを用いて、入力である複数の3次元点群がモデル点群に近づくよう教師あり機械学習させ、パラメータであるクォータニオンを推定しようとしています。損失関数は2つの3次元点群の距離を計算するChamfer distanceを用います。また、入力する3次元点群は列になって並んでおり、Chamfer distanceが最小になるよう自由に最適化をするとモデル点群の一箇所に集中してしまい、所望の結果から異なる結果になってしまいます。そのため、隣り合う3次元点群の回転の差は小さいことを最適化の制約条件(ペナルティ項)として損失関数に加えます。このペナルティ項はモデルのパラメータである各点群の回転を示すクォータニオンから、隣接するクォータニオンを抽出し相対回転の回転角(ロドリゲスの回転公式)で計算しようとしています。
文字情報だけだと分かりづらいと思いますので、イメージ図を添付いたします。
発生している問題・分からないこと
しかし、ペナルティ項の計算の際にモデルのパラメータであるクォータニオンに勾配が通らない問題が発生しています。この問題を解決して、先述したことを実現したいです。
作成したコードは次の通りで、ChatGPTに質問しながら作成しました。
該当のソースコード
python
1import numpy as np 2import open3d as o3d 3import math 4import torch 5from torch import nn 6from pytorch3d.transforms import quaternion_to_matrix 7from pytorch3d.transforms import quaternion_multiply 8 9device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 11class Estimate_Param(nn.Module): 12 def __init__(self, num): 13 super().__init__() 14 self.num = num 15 self.q = nn.Parameter(torch.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True).repeat(num, 1)) 16 17 def rotation_matrix_to_axis_angle(self, R): 18 # 回転角を計算 19 trace_R = torch.trace(R).to(device) 20 theta = torch.acos((trace_R - 1) / 2).to(device) 21 # 回転軸を計算 (正規化) 22 vx = R[2, 1] - R[1, 2] 23 vy = R[0, 2] - R[2, 0] 24 vz = R[1, 0] - R[0, 1] 25 26 axis = torch.tensor([vx, vy, vz]) 27 axis = axis / torch.norm(axis) # 正規化 28 29 return axis, theta 30 31 def forward(self, pcd_list, scan_pos_list): 32 output_list = [] 33 axis_list = [] 34 angle_list = [] 35 for i in range(self.num): 36 pcd_i = pcd_list[i] 37 scan_i = scan_pos_list[i] 38 39 pcd_i_translated = pcd_i - scan_i 40 # pcd_i_translated を float32 にキャスト 41 pcd_i_translated = pcd_i_translated.float() 42 43 q = self.q[i] / torch.norm(self.q[i]) 44 rotation_matrix = quaternion_to_matrix(q.unsqueeze(0)).squeeze(0) 45 pcd_i_rotated_q = torch.matmul(pcd_i_translated, rotation_matrix.T) 46 47 transformed_data = pcd_i_rotated_q + scan_i 48 49 output_list.append(transformed_data) 50 axis, angle = self.rotation_matrix_to_axis_angle(rotation_matrix) 51 axis_list.append(axis) 52 angle_list.append(angle) 53 54 return output_list, axis_list, angle_list 55 56# 相対回転角を計算する関数 57from pytorch3d.transforms import quaternion_multiply 58def relative_rotation_angle(q1, q2): 59 # q1とq2の相対クォータニオンを計算 (q1の逆数とq2の積) 60 q1_inv = torch.tensor([q1[0], -q1[1], -q1[2], -q1[3]]) 61 q_rel = quaternion_multiply(q1_inv, q2) 62 63 # 相対回転角を計算 64 theta = 2 * torch.acos(q_rel[0].clamp(-1.0, 1.0)) 65 # クランプして数値範囲を保つ 66 return theta 67 68# クォータニオンリストから全ての相対回転角を計算 69def calculate_penalty(model_q): 70 num = model_q.size(0) 71 penalty = torch.tensor(0.0, device=device) 72 73 # 隣接するクォータニオン間の相対回転角を計算 74 for i in range(1, num): 75 q1 = model_q[i - 1] / torch.norm(model_q[i - 1]) 76 q2 = model_q[i] / torch.norm(model_q[i]) 77 theta = relative_rotation_angle(q1, q2) 78 penalty += theta 79 80 return penalty 81 82model = Estimate_Param(num=num_lines).to(device=device) 83optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 84# モデルへの入力データ 85input_pcd_list = "列に並んだ複数の3次元点群座標が入ったtorch.tensorを格納するlist. [torch.tensor(N_1, 3), torch.tensor(N_2, 3), ..., torch.tensor(N_n, 3)]" 86input_scan_list = "input_pcd_listの各3次元点群のクォータニオンで回転する際の回転中心位置が入ったtorch.tensor. torch.tensor(n, 3)" 87# 損失関数計算の際に用いるモデル点群 88target = "モデル点群の座標が入ったtorch.tensor.detach(). モデル点群座標が学習ごとに変化してはいけないのでdetachしている" 89 90iteration = 100 91for epoch in range(iteration): 92 optimizer.zero_grad() 93 output_list, axis_list, angle_list = model(input_pcd_list, input_scan_pos_list) 94 sum_AB = 0 95 # penalty = torch.tensor(0.0, device=device) 96 for i in range(num_lines): 97 output = output_list[i] 98 chamfer_loss, chamfer_AB, chamfer_BA = Chamfer_distance(output, target) 99 sum_AB += chamfer_AB 100 penalty = calculate_penalty(model.q) 101 102 loss = sum_AB + penalty 103 loss.backward() 104 optimizer.step() 105 torch.cuda.empty_cache() 106 print(epoch + 1) 107 print("model.q.grad", model.q.grad) 108 print("loss", loss.item()) 109 print("chamfer_AB", sum_AB.item()) 110 print("penalty", penalty.item())
試したこと・調べたこと
- teratailやGoogle等で検索した
- ソースコードを自分なりに変更した
- 知人に聞いた
- その他
上記の詳細・結果
まず、コード最後にあるようにlossが正しく計算されているか、model.qに勾配が通っているか確認すると、以下のようになりました。
1
model.q.grad : tensor([[ 0.0000, 179.8098, -61.2189, -86.0011],
[ nan, nan, nan, nan],
[ nan, nan, nan, nan],
・・・
[ nan, nan, nan, nan]], device='cuda:0')
loss : 131.84036254882812
chamfer_AB : 131.84036254882812
penalty : 0.0
2
model.q.grad : tensor([[ -5.8747, -140.5616, 204.2349, 242.6726],
[ nan, nan, nan, nan],
[ nan, nan, nan, nan],
・・・
[ nan, nan, nan, nan]], device='cuda:0')
loss : nan
chamfer_AB : nan
penalty : nan
3
model.q.grad : tensor([[ -1.2400, -177.3816, -70.0906, -59.6504],
[ nan, nan, nan, nan],
[ nan, nan, nan, nan],
・・・
[ nan, nan, nan, nan]], device='cuda:0')
loss : nan
chamfer_AB : nan
penalty : nan
・・・
次に、loss = sum_ABとし、penalty項をlossに含めなかったところ
1
model.q.grad : tensor([[ 0.0000e+00, 1.7981e+02, -6.1219e+01, -8.6001e+01],
[ 0.0000e+00, 1.6622e+02, -9.8284e+01, -1.2649e+02],
[ 0.0000e+00, 8.4467e+01, -2.0669e+02, -2.4092e+02],
・・・
[ 0.0000e+00, -3.1394e+01, 2.8565e+01, 2.5363e+01]], device='cuda:0')
loss : 131.9658660888672
chamfer_AB : 131.9658660888672
penalty : 0.0
2
model.q.grad : tensor([[-5.8747e+00, -1.4056e+02, 2.0423e+02, 2.4267e+02],
[-7.7819e+00, -5.0197e+02, 1.0464e+02, 1.7157e+02],
[-3.8115e+00, -2.3785e+02, 5.6516e+01, 8.6785e+01],
・・・
[-1.7411e+00, 7.6988e+01, -5.1253e+01, -4.5868e+01]], device='cuda:0')
loss : 698.0641479492188
chamfer_AB : 698.0641479492188
penalty : 6.798530578613281
3
model.q.grad : tensor([[-1.2400e+00, -1.7738e+02, -7.0091e+01, -5.9650e+01],
[-1.9538e+00, -2.1566e+02, 3.8415e+01, 6.6295e+01],
[ 4.0716e+00, -3.6200e+02, -2.1852e+02, -2.0509e+02],
[ 2.4014e-01, -5.5282e+01, -7.5611e+00, -2.7705e+00]], device='cuda:0')
loss : 238.53053283691406
chamfer_AB : 238.53053283691406
penalty : 5.056334018707275
・・・
というようにlossの計算ができました。原因はpenalty項にあることは分かりましたが、model.q[0]のみ勾配が通りその他がnanになる原因が分かりません。アドバイスよろしくお願いいたします。
補足
ご回答にイメージ図を追加してほしいとあったので追記いたします。
1枚目が実現したいことを図示したもので、
2枚目が補足説明したものになります。
回答2件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。