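I am currently porting the function below, which was written in PyTorch, to Chainer.
PyTorch's `scatter_` (which writes values into the positions given by `index`, whereas `scatter_add_` accumulates them) behaves differently from Chainer's `scatter_add`,
so this part cannot simply be swapped one-for-one.
I would like to rewrite it so that it performs the same computation, but I can't come up with a good approach.
If you know a good way to do this, could you please advise me?
Thank you in advance.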
Function I want to implement
```python
from torch.nn.modules.loss import _WeightedLoss


class DiceLoss(_WeightedLoss):
    """
    Dice Loss for a batch of samples
    """

    def forward(self, output, target, weights=None, ignore_index=None):
        """
        Forward pass

        :param output: NxCxHxW Variable
        :param target: NxHxW LongTensor
        :param weights: C FloatTensor
        :param ignore_index: int index to ignore from loss
        :return:
        """
        eps = 0.0001

        # One-hot encode target into the same NxCxHxW shape as output
        encoded_target = output.detach() * 0
        if ignore_index is not None:
            mask = target == ignore_index
            target = target.clone()
            target[mask] = 0
            encoded_target.scatter_(1, target.unsqueeze(1), 1)
            # Zero out every channel at ignored positions
            mask = mask.unsqueeze(1).expand_as(encoded_target)
            encoded_target[mask] = 0
        else:
            encoded_target.scatter_(1, target.unsqueeze(1), 1)

        if weights is None:
            weights = 1

        # Per-channel sums over the batch and spatial axes
        intersection = output * encoded_target
        numerator = 2 * intersection.sum(0).sum(1).sum(1)
        denominator = output + encoded_target

        if ignore_index is not None:
            denominator[mask] = 0
        denominator = denominator.sum(0).sum(1).sum(1) + eps
        loss_per_channel = weights * (1 - (numerator / denominator))

        return loss_per_channel.sum() / output.size(1)
```
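One possible way around `scatter_` here: with a value of 1 and exactly one index per pixel, `encoded_target.scatter_(1, target.unsqueeze(1), 1)` is just a one-hot encoding along the channel axis. Since the encoded target needs no gradient, it could be built directly on the raw label array (in Chainer, something like `target.array`) with NumPy, or CuPy on GPU. The sketch below assumes NumPy only and uses hypothetical helper names `encode_target` and `dice_loss`; it reproduces the forward computation, not a Chainer `Function`. Porting the arithmetic on `output` to Chainer would presumably go through `chainer.functions` (e.g. `F.sum` with an axis tuple, `F.where` for the mask), which is not shown or tested here.

```python
import numpy as np


def encode_target(target, num_classes, ignore_index=None):
    """Hypothetical helper: one-hot encode NxHxW integer labels into NxCxHxW.

    Mirrors what encoded_target.scatter_(1, target.unsqueeze(1), 1) builds,
    including zeroing every channel at positions equal to ignore_index.
    """
    target = np.asarray(target)
    mask = None
    if ignore_index is not None:
        mask = target == ignore_index
        target = np.where(mask, 0, target)  # like target[mask] = 0 on a copy
    # np.eye(C)[target] has shape NxHxWxC; move the class axis to position 1
    encoded = np.eye(num_classes, dtype=np.float32)[target].transpose(0, 3, 1, 2)
    if mask is not None:
        # Broadcast the NxHxW mask across the channel axis and zero those pixels
        encoded = encoded * (~mask)[:, None, :, :]
        mask = np.broadcast_to(mask[:, None, :, :], encoded.shape)
    return encoded, mask


def dice_loss(output, target, weights=None, ignore_index=None):
    """Hypothetical helper: forward pass of the DiceLoss above, NumPy only."""
    eps = 0.0001
    output = np.asarray(output, dtype=np.float32)
    encoded, mask = encode_target(target, output.shape[1], ignore_index)
    if weights is None:
        weights = 1
    # Sum over batch and spatial axes, leaving one value per channel
    numerator = 2 * (output * encoded).sum(axis=(0, 2, 3))
    denominator = output + encoded
    if mask is not None:
        denominator = np.where(mask, 0, denominator)
    denominator = denominator.sum(axis=(0, 2, 3)) + eps
    loss_per_channel = weights * (1 - numerator / denominator)
    return loss_per_channel.sum() / output.shape[1]
```

For example, when `output` equals the one-hot target the loss comes out close to 0 (only `eps` keeps it from being exactly 0), and when `output` is all zeros it comes out as 1.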
scatter_(dim, index, src) [PyTorch]
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter
- src (Tensor or float) – the source element(s) to scatter
```python
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.2300]])
```
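For comparison, both `scatter_` examples above can be reproduced in NumPy with `numpy.put_along_axis` (available since NumPy 1.15), which, like `scatter_`, overwrites rather than accumulates. This is only a CPU illustration of the semantics, with `x` fixed to the printed values instead of `torch.rand`; whether CuPy offers the same function depends on the CuPy version, so that is not assumed here.

```python
import numpy as np

# Equivalent of torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23):
# row i gets 1.23 written at column index[i, 0]
z = np.zeros((2, 4), dtype=np.float32)
np.put_along_axis(z, np.array([[2], [3]]), 1.23, axis=1)

# Equivalent of torch.zeros(3, 5).scatter_(0, index, x) with x fixed to the
# values printed above: out[index[i, j], j] = x[i, j]
x = np.array([[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
              [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]], dtype=np.float32)
index = np.array([[0, 1, 2, 0, 0],
                  [2, 0, 0, 1, 2]])
out = np.zeros((3, 5), dtype=np.float32)
np.put_along_axis(out, index, x, axis=0)
```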
scatter_add(a, slices, b) [Chainer or CuPy]
- a (Variable) A variable.
- slices (int, slice, Ellipsis, None, integer array-like, boolean array-like or tuple of them) An integer, a slice, an ellipsis, a numpy.newaxis, an integer array-like, a boolean array-like, or a tuple of them.
- b (Variable) A variable that is scatter added to a. Its shape has to equal a[slices] because broadcasting of variables is not supported.
```python
>>> import numpy
>>> import cupy
>>> import cupyx
>>> a = cupy.zeros((6,), dtype=numpy.float32)
>>> i = cupy.array([1, 0, 1])
>>> v = cupy.array([1., 1., 1.])
>>> cupyx.scatter_add(a, i, v);
>>> a
array([1., 2., 0., 0., 0., 0.], dtype=float32)
```
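On the Chainer side, the key point may be that `slices` can be a tuple of integer arrays, so the one-hot target can be written as a scatter-add with one index array per axis. The sketch below uses `numpy.add.at`, which, as far as I understand, is the CPU counterpart of `cupyx.scatter_add` (repeated indices accumulate), together with `np.indices` to build the N, H, W index grids. Because every (n, h, w) position receives exactly one increment, adding into zeros gives the same result as PyTorch's overwriting `scatter_`. The corresponding Chainer call would presumably be `F.scatter_add(zeros, (n_idx, target, h_idx, w_idx), ones)` with `ones` shaped NxHxW to match `a[slices]`; that call is an assumption and is not exercised here.

```python
import numpy as np

# CPU analogue of the cupyx.scatter_add example: index 1 appears twice, so it accumulates
a = np.zeros((6,), dtype=np.float32)
np.add.at(a, np.array([1, 0, 1]), np.array([1., 1., 1.], dtype=np.float32))

# One-hot encoding expressed as a scatter-add with a tuple of index arrays
target = np.array([[[0, 1], [2, 1]]])  # example labels, N=1, H=2, W=2
N, H, W = target.shape
C = 3  # example number of classes
n_idx, h_idx, w_idx = np.indices((N, H, W))  # each has shape NxHxW
encoded = np.zeros((N, C, H, W), dtype=np.float32)
# Each (n, h, w) adds 1 at channel target[n, h, w]; no index repeats, so add == overwrite
np.add.at(encoded, (n_idx, target, h_idx, w_idx), 1.0)
```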