In TensorFlow, can you use a non-smooth function as a loss function, such as a piecewise function (or one defined with if-else)? If you can't, why can you use ReLU?
In this link S
TensorFlow does not automatically compute gradients for every function, even when one uses backend functions. Please see "Errors when Building up a Custom Loss Function", a question I posted about a task I was working on and later answered myself.
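As a rough illustration of why ReLU is usable while a hard step is not, here is a minimal sketch in plain Python (standard library only, not TensorFlow; `relu`, `step`, and `numeric_grad` are hypothetical helper names chosen for this example). ReLU is non-differentiable only at the single point 0 and has slope 1 for positive inputs, whereas a step function is flat wherever it is differentiable, so its gradient carries no training signal.

```python
def relu(x):
    # Piecewise linear: slope 0 for x < 0, slope 1 for x > 0.
    return x if x > 0 else 0.0


def step(x, threshold=0.5):
    # Piecewise constant: flat on both sides of the (assumed) threshold.
    return 1.0 if x > threshold else 0.0


def numeric_grad(f, x, eps=1e-6):
    # Central finite-difference estimate of df/dx at x.
    return (f(x + eps) - f(x - eps)) / (2 * eps)


# Sample points chosen away from the kinks at 0 and 0.5.
points = [-1.0, -0.3, 0.2, 0.9]
relu_grads = [numeric_grad(relu, x) for x in points]
step_grads = [numeric_grad(step, x) for x in points]

print("ReLU gradients:", relu_grads)
print("Step gradients:", step_grads)
```

The ReLU estimates are nonzero for the positive inputs, while every step estimate is zero, which is why a step-like loss or activation needs a smooth surrogate before gradient descent can do anything with it.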
That being said, a piecewise-constant (step) function can only be implemented by approximating it with a piecewise-differentiable function. The following is my MATLAB implementation of this idea. It can easily be extended to cases with more thresholds (junctures) and the desired boundary conditions.
function [s, ds] = QPWC_Neuron(z, sharp)
% A special case of a (quadruple) piecewise-constant neuron built from four sigmoid segments
% There are three thresholds (junctures): 0.25, 0.5, and 0.75
% sharp determines how steep the steps between two junctures are.
% The closer a point is to a juncture, the smaller its gradient becomes. Gradients at the junctures are approximately zero.
% It handles 1D signals only at present, and it must be preceded by another activation function whose output falls within [0, 1]
% Example:
% z = 0:0.001:1;
% sharp = 100;
% [s, ds] = QPWC_Neuron(z, sharp);
    LZ = length(z);
    s = zeros(size(z));
    ds = s;
    for l = 1:LZ
        if z(l) <= 0
            s(l) = 0;
            ds(l) = 0;
        elseif (z(l) > 0) && (z(l) <= 0.25)
            s(l) = 0.25 ./ (1+exp(-sharp*((z(l)-0.125)./0.25)));
            ds(l) = sharp/0.25 * (s(l)-0) * (1-(s(l)-0)/0.25);
        elseif (z(l) > 0.25) && (z(l) <= 0.5)
            s(l) = 0.25 ./ (1+exp(-sharp*((z(l)-0.375)./0.25))) + 0.25;
            ds(l) = sharp/0.25 * (s(l)-0.25) * (1-(s(l)-0.25)/0.25);
        elseif (z(l) > 0.5) && (z(l) <= 0.75)
            s(l) = 0.25 ./ (1+exp(-sharp*((z(l)-0.625)./0.25))) + 0.5;
            ds(l) = sharp/0.25 * (s(l)-0.5) * (1-(s(l)-0.5)/0.25);
        elseif (z(l) > 0.75) && (z(l) < 1)
            % For z above 0.75, a taller sigmoid (height 0.5, centered at z = 1) is used
            % so that gradient descent moves towards 1 faster than in the other cases
            s(l) = 0.5 ./ (1+exp(-sharp*((z(l)-1)./0.5))) + 0.75;
            ds(l) = sharp/0.5 * (s(l)-0.75) * (1-(s(l)-0.75)/0.5);
        else
            s(l) = 1;
            ds(l) = 0;
        end
    end
    % Plot the activation (left) and its derivative (right)
    figure;
    subplot 121, plot(z, s); xlim([0, 1]); grid on;
    subplot 122, plot(z, ds); xlim([0, 1]); grid on;
end
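
For readers working in Python rather than MATLAB, the same idea can be sketched as follows. This is a scalar, standard-library-only port of the branches above, not a TensorFlow op; the names `qpwc_neuron` and `_sigmoid` and the `segments` table are assumptions introduced for this example. It also compares the analytic derivative with a finite-difference estimate as a sanity check.

```python
import math


def _sigmoid(u):
    # Numerically stable logistic function (avoids overflow for large |u|).
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def qpwc_neuron(z, sharp):
    """Return (s, ds) for a scalar z, mirroring the MATLAB branches."""
    if z <= 0:
        return 0.0, 0.0
    if z >= 1:
        return 1.0, 0.0
    # (lower bound, upper bound, sigmoid center, step height, offset)
    segments = [
        (0.0, 0.25, 0.125, 0.25, 0.0),
        (0.25, 0.5, 0.375, 0.25, 0.25),
        (0.5, 0.75, 0.625, 0.25, 0.5),
        (0.75, 1.0, 1.0, 0.5, 0.75),
    ]
    for lo, hi, center, height, offset in segments:
        if lo < z <= hi:
            sig = _sigmoid(sharp * (z - center) / height)
            s = height * sig + offset
            # d/dz of height * sigmoid(sharp * (z - center) / height)
            ds = sharp * sig * (1.0 - sig)
            return s, ds
    raise ValueError("z must be a finite number")


# Compare the analytic derivative with a central finite difference.
z0, sharp0, eps = 0.13, 100, 1e-6
s0, ds0 = qpwc_neuron(z0, sharp0)
numeric = (qpwc_neuron(z0 + eps, sharp0)[0]
           - qpwc_neuron(z0 - eps, sharp0)[0]) / (2 * eps)
print(s0, ds0, numeric)
```

The derivative is written as `sharp * sig * (1 - sig)`, which is algebraically the same as the MATLAB expression `sharp/height * (s - offset) * (1 - (s - offset)/height)`.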