I've been messing around with a simple tensor library, in which I have defined the following type.
data Tensor : Vect n Nat -> Type -> Type where
    Scalar : a -> Tensor [] a
    Dimension : Vect n (Tensor d a) -> Tensor (n :: d) a
The vector parameter of the type describes the tensor's "dimensions" or "shape". I am currently trying to define a function to safely index into a Tensor. I had planned to do this using Fins, but I ran into an issue. Because the Tensor is of unknown order, I could need any number of indices, each requiring a different upper bound. This means that a Vect of indices would be insufficient, because each index would have a different type. That drove me to look at using tuples (called "pairs" in Idris?) instead. I wrote the following function to compute the necessary type.
TensorIndex : Vect n Nat -> Type
TensorIndex [] = ()
TensorIndex (d::[]) = Fin d
TensorIndex (d::ds) = (Fin d, TensorIndex ds)
This function worked as I expected, calculating the appropriate index type from a dimension vector.
> TensorIndex [4,4,3] -- (Fin 4, Fin 4, Fin 3)
> TensorIndex [2] -- Fin 2
> TensorIndex [] -- ()
But when I tried to define the actual index function:
index : {d : Vect n Nat} -> TensorIndex d -> Tensor d a -> a
index () (Scalar x) = x
index (a,as) (Dimension xs) = index as $ index a xs
index a (Dimension xs) with (index a xs) | Tensor x = x
...Idris raised the following error on the second case (oddly enough it seemed perfectly okay with the first).
Type mismatch between
(A, B) (Type of (a,as))
TensorIndex (n :: d) (Expected type)
The error seems to imply that instead of treating TensorIndex as an extremely convoluted type synonym and evaluating it like I had hoped it would, it treated it as though it were defined with a data declaration; a "black-box type" so to speak. Where does Idris draw the line on this? Is there some way for me to rewrite TensorIndex so that it works the way I want it to? If not, can you think of any other way to write the index function?
Your life becomes so much easier if you allow for a trailing () in your TensorIndex, since then you can just do:
TensorIndex : Vect n Nat -> Type
TensorIndex [] = ()
TensorIndex (d::ds) = (Fin d, TensorIndex ds)
index : {ds : Vect n Nat} -> TensorIndex ds -> Tensor ds a -> a
index {ds = []} () (Scalar x) = x
index {ds = _ :: ds} (i, is) (Dimension xs) = index is (index i xs)
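As a usage sketch (untested, written with Idris 1 in mind; the names sample, ix and picked are made up for illustration), indexing a 2x3 tensor with the definitions above might look like this:

```idris
module Main

import Data.Fin
import Data.Vect

data Tensor : Vect n Nat -> Type -> Type where
    Scalar    : a -> Tensor [] a
    Dimension : Vect n (Tensor d a) -> Tensor (n :: d) a

TensorIndex : Vect n Nat -> Type
TensorIndex []        = ()
TensorIndex (d :: ds) = (Fin d, TensorIndex ds)

index : {ds : Vect n Nat} -> TensorIndex ds -> Tensor ds a -> a
index {ds = []}      ()      (Scalar x)     = x
index {ds = _ :: ds} (i, is) (Dimension xs) = index is (index i xs)

-- Hypothetical 2x3 tensor, for illustration only.
sample : Tensor [2, 3] Nat
sample = Dimension [ Dimension [Scalar 1, Scalar 2, Scalar 3]
                   , Dimension [Scalar 4, Scalar 5, Scalar 6] ]

-- TensorIndex [2, 3] reduces to (Fin 2, (Fin 3, ())), hence the trailing ().
ix : TensorIndex [2, 3]
ix = (FS FZ, FS (FS FZ), ())

-- Second row, third column.
picked : Nat
picked = index ix sample
```

The cost of this version is visible in ix: every index tuple ends in (), even for a one-dimensional tensor.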
If you want to keep your definition of TensorIndex, you'll need to have separate cases for ds = [_] and ds = _::_::_ to match the structure of TensorIndex:
TensorIndex : Vect n Nat -> Type
TensorIndex [] = ()
TensorIndex (d::[]) = Fin d
TensorIndex (d::ds) = (Fin d, TensorIndex ds)
index : {ds : Vect n Nat} -> TensorIndex ds -> Tensor ds a -> a
index {ds = []} () (Scalar x) = x
index {ds = _ :: []} i (Dimension xs) with (index i xs) | (Scalar x) = x
index {ds = _ :: _ :: _} (i, is) (Dimension xs) = index is (index i xs)
The reason this works and yours didn't is that here each case of index corresponds to exactly one TensorIndex case, so TensorIndex ds can be reduced.
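One way to see this, as an untested sketch, is to state the reductions as equalities that hold by Refl. The lemma names oneDim and manyDims are made up for illustration. TensorIndex reduces once the vector is known to have the shape [d] or d :: e :: es, but for a bare d :: ds with ds unknown, Idris cannot tell which of the last two clauses applies, so the type stays stuck.

```idris
module Main

import Data.Fin
import Data.Vect

TensorIndex : Vect n Nat -> Type
TensorIndex []        = ()
TensorIndex (d :: []) = Fin d
TensorIndex (d :: ds) = (Fin d, TensorIndex ds)

-- Shape known to have exactly one dimension: the second clause fires.
oneDim : TensorIndex [d] = Fin d
oneDim = Refl

-- Shape known to have at least two dimensions: the third clause fires.
-- Pair is spelled out so the right-hand side is read as a type, not a tuple.
manyDims : TensorIndex (d :: e :: es) = Pair (Fin d) (TensorIndex (e :: es))
manyDims = Refl

-- By contrast, TensorIndex (d :: ds) with ds unknown cannot pick between
-- the second and third clauses, so no analogous lemma holds by Refl alone.
```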
Your definitions will be cleaner if you define Tensor by induction over the list of dimensions whilst the Index is defined as a datatype.
Indeed, at the moment you are forced to pattern-match on the implicit argument of type Vect n Nat to see what shape the index has. But if the index is defined directly as a piece of data, it then constrains the shape of the structure it indexes into and everything falls into place: the right piece of information arrives at the right time for the typechecker to be happy.
module Tensor
import Data.Fin
import Data.Vect
tensor : Vect n Nat -> Type -> Type
tensor [] a = a
tensor (m :: ms) a = Vect m (tensor ms a)
data Index : Vect n Nat -> Type where
    Here : Index []
    At : Fin m -> Index ms -> Index (m :: ms)
index : Index ms -> tensor ms a -> a
index Here a = a
index (At k i) v = index i $ index k v
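A usage sketch for this version (untested, written with Idris 1 in mind; matrix, ix and picked are hypothetical names added for illustration). The index is given its own type annotation so that its shape, rather than the computed tensor type, tells the typechecker which dimensions are involved:

```idris
module Tensor

import Data.Fin
import Data.Vect

tensor : Vect n Nat -> Type -> Type
tensor []        a = a
tensor (m :: ms) a = Vect m (tensor ms a)

data Index : Vect n Nat -> Type where
    Here : Index []
    At   : Fin m -> Index ms -> Index (m :: ms)

index : Index ms -> tensor ms a -> a
index Here     a = a
index (At k i) v = index i $ index k v

-- Hypothetical 2x3 matrix; tensor [2, 3] Nat reduces to Vect 2 (Vect 3 Nat).
matrix : tensor [2, 3] Nat
matrix = [[1, 2, 3], [4, 5, 6]]

-- The index's own type fixes the shape it can be used with.
ix : Index [2, 3]
ix = At (FS FZ) (At (FS (FS FZ)) Here)

-- Second row, third column.
picked : Nat
picked = index ix matrix
```

Because tensor is a plain function, a tensor here is just nested Vects, so no Scalar or Dimension wrappers are needed when building matrix.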