posmax: like argmax but gives the position(s) of the element x for which f[x] is maximal

旧时模样 submitted on 2019-12-06 12:43:40

@dreeves, you're correct in that Ordering is the key to the fastest implementation of ArgMax over a finite domain:

ArgMax[f_, dom_List] := dom[[Ordering[f /@ dom, -1]]]
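To see what Ordering contributes here, it may help to try it on a small input. The following is only a sketch: the name argMaxOne, the function f, and the sample domain are made up for illustration (a different name is used because recent Mathematica versions ship a built-in ArgMax).

```mathematica
(* argMaxOne is a hypothetical name, chosen to avoid the built-in ArgMax *)
argMaxOne[f_, dom_List] := dom[[Ordering[f /@ dom, -1]]]

(* made-up test function: peaks near x = 32 *)
f = Function[x, -(x - 32)^2];
dom = {10, 20, 30, 40, 50};

f /@ dom                   (* {-484, -144, -4, -64, -324} *)
Ordering[f /@ dom, -1]     (* {3}: position of the largest value *)
argMaxOne[f, dom]          (* {30}: note the result is a one-element list *)
```

Because Ordering[list, -1] returns a list of positions, the result is {30} rather than 30; wrap the call in First if a bare element is wanted.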

Part of the problem with your original implementation using Fold is that you end up evaluating f twice as often as necessary, which is wasteful, especially when f is slow to compute. Here we evaluate f only once for each member of the domain. When the domain has many duplicated elements, we can optimize further by memoizing the values of f:

ArgMax[f_, dom_List] :=
  Module[{g},
    g[e___] := g[e] = f[e]; (* memoize *)
    dom[[Ordering[g /@ dom, -1]]]
  ]

This was about 30% faster in some basic tests for a list of 100,000 random integers between 0 and 100.
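One way to check that the memoized version really evaluates f only once per distinct element is to count the calls. This is a rough sketch, not a benchmark: slowF, calls, and memoArgMax are hypothetical names introduced here, and x^2 merely stands in for an expensive function.

```mathematica
calls = 0;
slowF[x_] := (calls++; x^2)   (* hypothetical stand-in for an expensive f *)

(* same memoized definition as above, under a non-clashing name *)
memoArgMax[f_, dom_List] :=
  Module[{g},
    g[e___] := g[e] = f[e];   (* cache each value the first time it is computed *)
    dom[[Ordering[g /@ dom, -1]]]
  ]

memoArgMax[slowF, {1, 2, 2, 3, 3, 3}]   (* {3} *)
calls                                    (* 3: one call per distinct element, not 6 *)
```

The saving depends entirely on how many duplicates the domain contains; with all-distinct elements the cache only adds overhead.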

For a posmax function, this somewhat inelegant approach is the fastest thing I can come up with:

PosMax[f_, dom_List] :=
  Module[{y = f/@dom},
    Flatten@Position[y, Max[y]]
  ]

Of course, we can apply memoization again:

PosMax[f_, dom_List] := 
  Module[{g, y},
    g[e___] := g[e] = f[e];
    y = g /@ dom;
    Flatten@Position[y, Max[y]]
  ]

To get all the maximal elements, you could now just implement ArgMax in terms of PosMax:

ArgMax[f_, dom_List] := dom[[PosMax[f, dom]]]
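A small example with a tie shows the difference from the Ordering-based version, which returns a single element. The names posMaxAll and argMaxAll are hypothetical renamings used here so the sketch does not collide with the built-in ArgMax; the input list is made up.

```mathematica
posMaxAll[f_, dom_List] :=
  Module[{y = f /@ dom},
    Flatten@Position[y, Max[y]]   (* every position where the maximum occurs *)
  ]

argMaxAll[f_, dom_List] := dom[[posMaxAll[f, dom]]]

(* Abs maps {-3, 1, 3, 2} to {3, 1, 3, 2}, so the maximum 3 occurs twice *)
posMaxAll[Abs, {-3, 1, 3, 2}]   (* {1, 3} *)
argMaxAll[Abs, {-3, 1, 3, 2}]   (* {-3, 3} *)
```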

For posmax, you can first map the function over the list and then ask for the position of the maximal element(s), i.e.:

posmax[f_, dom_List] := posmax[f /@ dom]

where posmax[list] is polymorphically defined to return the position of the maximal element. It turns out there's a built-in function, Ordering, that essentially does this. So we can define the single-argument version of posmax like this (note that it returns a single position, even if the maximum occurs more than once):

posmax[dom_List] := Ordering[dom, -1][[1]]
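As a quick illustration of the single-position behavior, here is a sketch with a tied maximum. posMaxLast is a hypothetical name and the list is made up. As far as I know, Ordering keeps equal elements in their original relative order, so with -1 the tie resolves to the last occurrence.

```mathematica
posMaxLast[list_List] := Ordering[list, -1][[1]]

Ordering[{5, 1, 5}]       (* {2, 1, 3}: equal elements keep their original order *)
posMaxLast[{5, 1, 5}]     (* 3: only the last of the tied positions *)

(* for comparison, Position reports every occurrence of the maximum *)
Flatten@Position[{5, 1, 5}, Max[{5, 1, 5}]]   (* {1, 3} *)
```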

I just tested that against a loop-based version and a recursive version, and Ordering is many times faster. The recursive version is pretty, so I'll show it off here, but don't ever try to run it on large inputs! Its recursion depth grows linearly with the length of the list, so it soon runs into $RecursionLimit.

(* posmax0 is a helper function for posmax that returns a pair with the position 
   and value of the max element. n is an accumulator variable, in lisp-speak. *)
posmax0[{h_}, n_:0] := {n+1, h}
posmax0[{h_, t___}, n_:0] := With[{best = posmax0[{t}, n+1]},
  If[h >= best[[2]], {n+1, h}, best]]

posmax[dom_List] := First@posmax0[dom, 0]
posmax[f_, dom_List] := First@posmax0[f /@ dom, 0]
posmax[_, {}] := 0

None of this addresses the question of how to find all the maximal elements (or positions of them). That doesn't normally come up for me in practice, though I think it would be good to have.
