Matlab: Argmax and dot product for each row in a matrix

Asked by 春和景丽, 2021-01-21 07:09

I have two matrices, X ∈ R^(n×m) and W ∈ R^(k×m), where k < … . Let x_i be the i-th row of X and w_j be the …
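Going by the title and both answers, the task is: for each row x_i, find the index j that maximizes the dot product w_j · x_i, i.e. Y(i) = argmax_j (w_j · x_i). A straightforward double-loop version is sketched below as a baseline. It is a hypothetical reference, not the asker's own code (which is cut off above), and the small X and W are made-up example data.

```matlab
% Hypothetical example data (not from the question): n = 4 rows in X,
% k = 3 rows in W, m = 2 columns in both.
X = [1 0; 0 1; 2 -1; -1 3];
W = [1 1; 2 0; 0 2];

% Reference loop: for each row x_i of X, compute the dot product with every
% row w_j of W and keep the index j of the largest one.
n = size(X, 1);
k = size(W, 1);
Yloop = zeros(1, n);
for i = 1:n
    best = -Inf;
    for j = 1:k
        d = W(j, :) * X(i, :)';   % dot product w_j . x_i
        if d > best               % strict '>' keeps the first j on ties, like max
            best = d;
            Yloop(i) = j;
        end
    end
end
disp(Yloop)   % Yloop is [2 3 2 3] for this data
```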

2 answers
  • 2021-01-21 07:26

    Taking the dot product of every row of W with every row of X is just a matrix multiplication:

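    % W*X' is k-by-n; entry (j,i) is the dot product of w_j and x_i
    [~, Y] = max(W*X');   % max runs down each column, so Y(i) is the best j for x_i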
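To see why this works: entry (j,i) of W*X' is sum over l of W(j,l)*X(i,l), which is exactly w_j · x_i, and max with a single output-index argument works column by column. The sketch below uses small hypothetical data to show the intermediate matrix, plus the equivalent form X*W' with max along dimension 2 for anyone who prefers Y as an n-by-1 column.

```matlab
X = [1 0; 0 1; 2 -1; -1 3];   % n-by-m, hypothetical example data
W = [1 1; 2 0; 0 2];          % k-by-m

S = W * X';        % k-by-n: S(j,i) is the dot product of w_j and x_i
[~, Y] = max(S);   % max works down each column, so Y(i) = argmax_j S(j,i)

% Same indices as an n-by-1 column: X*W' is n-by-k, so take the max along rows.
[~, Ycol] = max(X * W', [], 2);

disp(S)
disp(Y)            % 1-by-n row vector
```

When several rows of W tie for the maximum, max returns the first such index, so both forms agree.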
    
  • 2021-01-21 07:33

    Here is a bsxfun-based approach to speed things up for you:

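    % permute(W,[3 2 1]) is 1-by-m-by-k, so bsxfun builds an n-by-m-by-k array of
    % elementwise products; summing along dimension 2 gives the n-by-1-by-k dot products
    [~,Y] = max(sum(bsxfun(@times,X,permute(W,[3 2 1])),2),[],3);   % Y is n-by-1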
    

    On my system, with your dataset, I get a speedup of more than 100x with this.
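The sketch below breaks the one-liner into steps and cross-checks it against the matrix-multiplication answer. The sizes are hypothetical (not the asker's dataset), and integer-valued data is used so every dot product is computed exactly and both methods break ties identically. It also shows the implicit-expansion form X .* permute(W,[3 2 1]), which MATLAB supports from R2016b onward as a replacement for bsxfun. Note that either way the intermediate array has n·m·k elements, so memory use grows with all three dimensions.

```matlab
n = 1000; k = 20; m = 50;       % hypothetical sizes, not the asker's dataset
X = randi(10, n, m);            % integer entries keep the dot products exact
W = randi(10, k, m);

P = bsxfun(@times, X, permute(W, [3 2 1]));  % n-by-m-by-k: P(i,:,j) = x_i .* w_j
D = sum(P, 2);                               % n-by-1-by-k: D(i,1,j) = w_j . x_i
[~, Y] = max(D, [], 3);                      % n-by-1: index j of the largest dot product

% Since R2016b, implicit expansion builds the same array without bsxfun.
P2 = X .* permute(W, [3 2 1]);
[~, Y2] = max(sum(P2, 2), [], 3);

% Cross-check against the matrix-multiplication answer (a 1-by-n row).
[~, Ymm] = max(W * X');
```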


    One can think of two more closely related approaches, but they don't seem to give any big improvement over the first one:

    [~,Y] = max(squeeze(sum(bsxfun(@times,X,permute(W,[3 2 1])),2)),[],2);   % Y is n-by-1
    

    and

    [~,Y] = max(squeeze(sum(bsxfun(@times,X',permute(W,[2 3 1]))))');   % Y is 1-by-n
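A quick consistency check of the two variants, sketched with hypothetical sizes and integer data: the first returns Y as an n-by-1 column, the second as a 1-by-n row, matching the orientation of max(W*X'). One caveat worth checking for your data: squeeze turns a 1-by-1-by-k array into a k-by-1 column, so both squeeze-based variants assume n > 1, and the second also assumes k > 1. The first bsxfun one-liner, which takes the max along dimension 3, does not depend on squeeze.

```matlab
n = 500; k = 10; m = 30;        % hypothetical sizes (n > 1 and k > 1)
X = randi(10, n, m);            % integer data, so every method sums exactly
W = randi(10, k, m);

% Variant 1: squeeze the n-by-1-by-k sums to n-by-k, then take the max along rows.
[~, Ya] = max(squeeze(sum(bsxfun(@times, X, permute(W, [3 2 1])), 2)), [], 2);

% Variant 2: work with X' (m-by-n); summing over dimension 1 gives 1-by-n-by-k,
% squeeze gives n-by-k, and the transpose gives k-by-n for max down columns.
[~, Yb] = max(squeeze(sum(bsxfun(@times, X', permute(W, [2 3 1]))))');

% Reference: the matrix-multiplication answer, a 1-by-n row.
[~, Yref] = max(W * X');
```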
    