Self Training Algorithm

后端 未结 1 1188
故里飘歌
故里飘歌 2021-02-03 12:40

I'd like to develop a self-training algorithm for a specific problem. To keep things simple, I'll narrow it down to a simple example.

Update: I have added a working

相关标签:
1条回答
  • 2021-02-03 13:31

    After taking a break, I came up with a solution that seems to fit my requirements. The limitation is that all tested properties must be of the same type and share the same value range, which is fine in my case because all properties are abstract percentage values.

    By the way, I'm not sure whether the title "self training algorithm" is a little misleading here. There are several ways to implement such a solution, but if you have no idea how your data behaves or what effect the values have, the simplest solution is to brute-force all possible combinations to identify the best-fitting result. That's what I'm showing here.

    Anyway, for testing purposes I added a random number generator to my entity class.

    /// <summary>
    /// Sample record with four byte-sized properties that are filled with
    /// random values when an instance is constructed.
    /// </summary>
    public class Entity
    {
        public byte Prop1 { get; set; }
        public byte Prop2 { get; set; }
        public byte Prop3 { get; set; }
        public byte Prop4 { get; set; }

        /// <summary>
        /// Creates an entity and assigns one random byte to each property.
        /// </summary>
        public Entity()
        {
            // The generator is seeded from the hash code of a fresh GUID.
            int seed = Guid.NewGuid().GetHashCode();
            byte[] buffer = new byte[ 4 ];

            new Random( seed ).NextBytes( buffer );

            Prop1 = buffer[0];
            Prop2 = buffer[1];
            Prop3 = buffer[2];
            Prop4 = buffer[3];
        }
    }
    

    My bitmask stays untouched.

    /// <summary>
    /// Bit flags naming the four entity properties. Each named property
    /// occupies its own bit, so several of them can be combined in one mask.
    /// </summary>
    [Flags]
    public enum EProperty
    {
        Undefined = 0x0,
        Prop1 = 0x1,
        Prop2 = 0x2,
        Prop3 = 0x4,
        Prop4 = 0x8
    }
    

    Then I added some new extension methods to deal with my bitmask.

    /// <summary>
    /// Helpers for treating an enum as a bitmask. Every method converts the
    /// enum value to <c>int</c> through a boxing cast, so the enum is expected
    /// to have <c>int</c> as its underlying type.
    /// </summary>
    public static class BitMask
    {
        /// <summary>
        /// Returns the mask in which every bit used by any declared member of
        /// <typeparamref name="T"/> is set.
        /// </summary>
        /// <typeparam name="T">The enum type to inspect.</typeparam>
        /// <returns>The bitwise OR of all declared member values.</returns>
        public static int GetMaxValue<T>() where T : struct
        {
            // OR instead of adding: a sum counts a bit twice as soon as the enum
            // declares a composite member (e.g. ReadWrite = Read | Write) and
            // would then return a value larger than the real maximum mask.
            return Enum.GetValues(typeof (T)).Cast<int>().Aggregate(0, (mask, flag) => mask | flag);
        }

        /// <summary>
        /// Counts the declared members of <typeparamref name="T"/> whose value
        /// is greater than zero.
        /// </summary>
        /// <typeparam name="T">The enum type to inspect.</typeparam>
        /// <returns>The number of members with a positive value.</returns>
        public static int GetTotalCount<T>() where T : struct
        {
            return Enum.GetValues(typeof (T)).Cast<int>().Count(e => e > 0);
        }

        /// <summary>
        /// Counts the bits that are set in <paramref name="mask"/>.
        /// </summary>
        /// <typeparam name="T">The enum type of the mask.</typeparam>
        /// <param name="mask">The mask to inspect.</param>
        /// <returns>The number of set bits.</returns>
        public static int GetFlagCount<T>(this T mask) where T : struct
        {
            int result = 0, value = (int) (object) mask;

            while (value != 0)
            {
                // value & (value - 1) clears the lowest set bit, so the loop
                // runs once per set bit.
                value = value & (value - 1);
                result++;
            }

            return result;
        }

        /// <summary>
        /// Lazily yields every declared member of <typeparamref name="T"/> whose
        /// value shares at least one bit with <paramref name="mask"/>.
        /// Zero-valued members are never yielded.
        /// </summary>
        /// <typeparam name="T">The enum type of the mask.</typeparam>
        /// <param name="mask">The mask to split.</param>
        /// <returns>The matching members, in the order reported by <c>Enum.GetValues</c>.</returns>
        public static IEnumerable<T> Split<T>(this T mask)
        {
            int maskValue = (int) (object) mask;

            foreach (T flag in Enum.GetValues(typeof (T)))
            {
                int flagValue = (int) (object) flag;

                if (0 != (flagValue & maskValue))
                {
                    yield return flag;
                }
            }
        }
    }
    

    Then I wrote a query builder.

    /// <summary>
    /// Builds lazy threshold filters over a sequence of <c>Entity</c> objects.
    /// </summary>
    public static class QueryBuilder
    {
        /// <summary>
        /// Filters <paramref name="entities"/> so that, for every pair
        /// <c>(properties[i], values[i])</c>, only entities whose named property
        /// is greater than or equal to the value remain. The filters are combined
        /// with AND and evaluated lazily when the result is enumerated.
        /// </summary>
        /// <param name="entities">The sequence to filter.</param>
        /// <param name="properties">The properties to test; entries that do not name exactly one of Prop1..Prop4 are ignored.</param>
        /// <param name="values">The inclusive lower bound for the property at the same index.</param>
        /// <returns>A lazily evaluated sequence of the matching entities.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="values"/> has fewer elements than <paramref name="properties"/>.</exception>
        public static IEnumerable<Entity> Where(this IEnumerable<Entity> entities, EProperty[] properties, int[] values)
        {
            if (entities == null)
            {
                throw new ArgumentNullException(nameof(entities));
            }

            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }

            if (values == null)
            {
                throw new ArgumentNullException(nameof(values));
            }

            // Fail up front instead of with an IndexOutOfRangeException halfway
            // through building the filter chain.
            if (values.Length < properties.Length)
            {
                throw new ArgumentException("A value is required for every property.", nameof(values));
            }

            // The identity projection keeps the result a read-only view even
            // when no filter is added.
            IEnumerable<Entity> result = entities.Select(e => e);

            for (int index = 0; index < properties.Length; index++)
            {
                Func<Entity, int> selector = GetSelector(properties[index]);

                if (selector == null)
                {
                    // Undefined or combined flags add no filter.
                    continue;
                }

                // Declared inside the loop so each predicate captures its own copy.
                int threshold = values[index];
                result = result.Where(e => selector(e) >= threshold);
            }

            return result;
        }

        // Maps a single property flag to the accessor that reads it, or null when
        // the flag does not name exactly one property. The properties are bytes,
        // hence never negative, so no absolute value is needed.
        private static Func<Entity, int> GetSelector(EProperty property)
        {
            switch (property)
            {
                case EProperty.Prop1:
                    return e => e.Prop1;
                case EProperty.Prop2:
                    return e => e.Prop2;
                case EProperty.Prop3:
                    return e => e.Prop3;
                case EProperty.Prop4:
                    return e => e.Prop4;
                default:
                    return null;
            }
        }
    }
    

    And finally, I'm ready to run the training.

        // Upper bound for the degree of parallelism used by the training loop.
        private const int maxThreads = 10;
    
        // Inclusive lower and upper bound of the threshold values that are tried
        // for every property.
        private const int minValue = 0;
        private const int maxValue = 100;
    
        // The data set under test; assigned once in Main and read by RunTestCase.
        private static IEnumerable<Entity> entities;
    
        /// <summary>
        /// Entry point: creates the test data, runs the brute-force training and
        /// prints the wall-clock time before and after the run.
        /// </summary>
        /// <param name="args">Command line arguments (unused).</param>
        public static void Main(string[] args)
        {
            Console.WriteLine(DateTime.Now.ToLongTimeString());

            // Enumerable.Repeat(new Entity(), 10) evaluates the constructor once
            // and returns ten references to that single object. Projecting a
            // range creates ten distinct, independently randomized entities.
            entities = Enumerable.Range(0, 10).Select(_ => new Entity()).ToList();

            Action<EProperty[], int[]> testCase = RunTestCase;
            RunSelfTraining( testCase );

            Console.WriteLine(DateTime.Now.ToLongTimeString());
            Console.WriteLine("Done.");

            // Keep the console window open until a key is pressed.
            Console.Read();
        }
    
        /// <summary>
        /// Runs one test case: builds the filtered query for the given
        /// property/value pairs and enumerates it completely. The matches
        /// themselves are not processed any further.
        /// </summary>
        /// <param name="properties">The properties to filter on.</param>
        /// <param name="values">The threshold for the property at the same index.</param>
        private static void RunTestCase( EProperty[] properties, int[] values ) 
        {
            IEnumerable<Entity> matches = entities.Where( properties, values );

            using( IEnumerator<Entity> cursor = matches.GetEnumerator() )
            {
                while( cursor.MoveNext() )
                {
                    // Each match is fetched and then deliberately left unused.
                    Entity entity = cursor.Current;
                }
            }
        }
    
        /// <summary>
        /// Brute-forces every non-empty combination of the flags of
        /// <typeparamref name="T"/> and, for each combination, every assignment
        /// of values between <c>minValue</c> and <c>maxValue</c> (inclusive) to
        /// the selected flags, invoking <paramref name="testCase"/> once per
        /// assignment. The assignments of one combination run in parallel.
        /// </summary>
        /// <typeparam name="T">The flags enum that names the properties.</typeparam>
        /// <param name="testCase">Callback receiving the selected flags and one value per flag.</param>
        /// <exception cref="OverflowException">The number of assignments for a combination does not fit into an <c>int</c>.</exception>
        private static void RunSelfTraining<T>( Action<T[], int[]> testCase ) where T : struct
        {
            ParallelOptions parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = maxThreads };

            // Loop-invariant, so it is computed once instead of on every iteration.
            int maxMask = BitMask.GetMaxValue<T>();

            for (int maskValue = 1; maskValue <= maxMask; maskValue++)
            {
                T mask = ( T ) (object)maskValue;
                T[] properties = mask.Split().ToArray();

                // Checked: an unchecked conversion of a too large count yields an
                // unspecified value and would silently skip the whole combination.
                int variations = checked( (int) Math.Pow(maxValue - minValue + 1, properties.Length) );

                // GetVariation expects indices 1..variations and the upper bound
                // of Parallel.For is exclusive, hence variations + 1. Without the
                // + 1 the last assignment (all values == maxValue) is never run.
                Parallel.For(1, checked( variations + 1 ), parallelOptions, variation =>
                {
                    int[] values = GetVariation(variation, minValue, maxValue, properties.Length).ToArray();
                    testCase.Invoke(properties, values);
                } );
            }
        }
    
        /// <summary>
        /// Decodes a 1-based variation index into <paramref name="count"/> values,
        /// treating <c>index - 1</c> as a number written in base
        /// <c>maxValue - minValue + 1</c>. The digits are yielded lazily, least
        /// significant first, each shifted by <paramref name="minValue"/>.
        /// </summary>
        /// <param name="index">The 1-based index of the variation.</param>
        /// <param name="minValue">The smallest value a position can take.</param>
        /// <param name="maxValue">The largest value a position can take.</param>
        /// <param name="count">The number of values to produce.</param>
        /// <returns>The values of the variation.</returns>
        public static IEnumerable<int> GetVariation( int index, int minValue, int maxValue, int count )
        {
            int radix = maxValue - minValue + 1;
            int remaining = index - 1;
            int produced = 0;

            while( produced < count )
            {
                int digit = remaining % radix;
                yield return digit + minValue;

                remaining /= radix;
                produced++;
            }
        }
    }
    
    0 讨论(0)
提交回复
热议问题