NEAT Algorithm on Snake-Game only makes snakes go in circles

107 views Asked by At

I'm trying to implement NEAT algorithm to a snake game i coded. However, it doesn't work, and the snakes just go around non-stop. The last thing i try, was to implement some sort of "danger or food check", to tell the snake if the food was near or not. It looks in 3 different ways by 8 blocks.

import pygame as pg
import os, neat, random

SCREEN_HEIGHT = 650
SCREEN_WIDTH = SCREEN_HEIGHT - 50
WIDTH = SCREEN_WIDTH
HEIGHT = SCREEN_HEIGHT - 50
BLOCKSIZE = 30
COLS = WIDTH // BLOCKSIZE
ROWS = HEIGHT // BLOCKSIZE

class Snake:
    def __init__(self, screen, color):
        self.screen = screen
        self.location = 
        pg.math.Vector2(COLS/2,ROWS/2)
        self.moving_to = pg.math.Vector2(1,0)
        self.length = 1
        self.positions = [self.location]
        self.alive = True
        self.hungry = 0
        self.color = color
        self.food = Food(screen, color)
        self.distanceHistory = []

    def reset(self):
        self.location = pg.math.Vector2(COLS/2,ROWS/2)
        self.length = 1
        self.alive = True
        self.positions = [self.location]
        
    def render(self):
        for position in self.positions:
            pg.draw.rect(self.screen, self.color, (position[0] * BLOCKSIZE, position[1] * BLOCKSIZE , BLOCKSIZE, BLOCKSIZE))
        
        self.food.render()
    def move(self):
        new_position = self.location + self.moving_to
    
        if self.alive:
            self.location += self.moving_to 
            #Move and insert new position to keep track of the tail
            self.positions.insert(0, new_position)
            if len(self.positions) > self.length:
                #Pop if not necessary the oldest position saved
                self.positions.pop()
    
    def threeWaystoDie(self):
        self.hungry += 1
        new_position = self.location + self.moving_to
        #If touches it's tail
        if len(self.positions) and new_position in self.positions[2:]:
            self.alive = False
        #If touches a wall
        elif not (0 <= self.location[0] <= COLS - 1  and 0 <= self.location[1] <= ROWS - 1):
            self.alive = False  
        
        if self.hungry >= 200:
            self.alive = False
        
        
    
    def checkDangerOrFood(self, blockstoCheck):
        straight = []
        right = []
        left = []
    
        #Each moving to, has 3 possible directions to check
        #If wall or tail near 
        #RIGHT
            #straight
        if (self.moving_to[0], self.moving_to[1]) == (1, 0):
            for i in range(blockstoCheck):
                new_position = self.location + self.moving_to * i
                if new_position[0] >= COLS - 1 or new_position in self.positions[2:]:
                    straight.append(-1)
                elif new_position == self.food.location:
                    straight.append(1)
                    
                else:
                    straight.append(0)
            
            #up
            for i in range(blockstoCheck):
                new_position = self.location + pg.math.Vector2(0, -1) * i
                if new_position[1] <= 0 or new_position in self.positions[2:]:
                    left.append(-1)
                    
                elif new_position == self.food.location:
                    left.append(1)
                    
                else:
                    left.append(0)
            #down
            for i in range(blockstoCheck):
                new_position = self.location + pg.math.Vector2(0, 1) * i
                if new_position[1] <= 0 or new_position in self.positions[2:]:
                    right.append(-1)
                    
                elif new_position == self.food.location:
                    right.append(1)
                    
                else:
                    right.append(0)
        #LEFT
            #straight
        if (self.moving_to[0], self.moving_to[1]) == (-1, 0):
            for i in range(blockstoCheck):
                new_position = self.location + self.moving_to * i
                if new_position[0] <= 0 or new_position in self.positions[2:]:
                    straight.append(-1)
                    
                elif new_position == self.food.location:
                    straight.append(1)
                    
                else:
                    straight.append(0)
            
            #down
            for i in range(blockstoCheck):
                new_position = self.location + pg.math.Vector2(0, 1) * i
                if new_position[1] >= ROWS - 1 or new_position in self.positions[2:]:
                    left.append(-1)
                    
                elif new_position == self.food.location:
                    left.append(1)
                    
                else:
                    left.append(0)
            #up
            for i in range(blockstoCheck):
                new_position = self.location + pg.math.Vector2(0, -1) * i
                if new_position[1] <= 0 or new_position in self.positions[2:]:
                    right.append(-1)
                    
                elif new_position == self.food.location:
                    right.append(1)
                    
                else:
                    right.append(0)
        #UP
            #straight
        if (self.moving_to[0], self.moving_to[1]) == (0, -1):
            for i in range(blockstoCheck):
                new_position = self.location + self.moving_to * i
                if new_position[1] <= 0 or new_position in self.positions[2:]:
                    straight.append(-1)
                    
                elif new_position == self.food.location:
                    straight.append(1)
                    
                else:
                    straight.append(0)
            
            #left
            for i in range(blockstoCheck):
                new_position = self.location + pg.math.Vector2(-1, 0) * i
                if new_position[0] <= 0 or new_position in self.positions[2:]:
                    left.append(-1)
                    
                elif new_position == self.food.location:
                    left.append(1)
                    
                else:
                    left.append(0)
            #right
            for i in range(blockstoCheck):
                new_position = self.location + pg.math.Vector2(1, 0) * i
                if new_position [0] >= COLS - 1 or new_position in self.positions[2:]:
                    right.append(-1)
                    
                elif new_position == self.food.location:
                    right.append(1)
                    
                else:
                    right.append(0)
        #DOWN
            #straight
        if (self.moving_to[0], self.moving_to[1]) == (0, 1):
            for i in range(blockstoCheck):
                new_position = self.location + self.moving_to * i
                if new_position[1] >= ROWS - 1 or new_position in self.positions[2:]:
                    straight.append(-1)
                    
                elif new_position == self.food.location:
                    straight.append(1)
                    
                else:
                    straight.append(0)
            
            #left
            for i in range(blockstoCheck):
                new_position = self.location + pg.math.Vector2(0, 1) * i
                if new_position[0] <= 0 or new_position in self.positions[2:]:
                    left.append(-1)
                    
                elif new_position == self.food.location:
                    left.append(1)
                    
                else:
                    left.append(0)
            #right
            for i in range(blockstoCheck):
                new_position = self.location + pg.math.Vector2(0, -1) * i
                if new_position[0] >= COLS - 1 or new_position in self.positions[2:]:
                    right.append(-1)
                    
                elif new_position == self.food.location:
                    right.append(1)
                    
                else:
                    right.append(0)
        return straight, left, right

class Food:
    def __init__(self, screen, color):
        self.screen = screen
        self.location = pg.math.Vector2(randint(0, COLS - 1), randint(0, ROWS - 1))
        self.color = color

    def render(self):
        pg.draw.circle(self.screen, self.color, (self.location[0] * BLOCKSIZE + (BLOCKSIZE//2), self.location[1] * BLOCKSIZE+ (BLOCKSIZE//2)), 10)
    
    def reset(self, positionsOfSnake):
        self.location = pg.math.Vector2(randint(0, COLS - 1), randint(0, ROWS - 1))
        #Not to spawn on the snake's tail
        while self.location in positionsOfSnake:
            self.location = pg.math.Vector2(randint(0, COLS - 1), randint(0, ROWS - 1))

GEN = 0
colors = ["red", "yellow", "purple", "white", "green", "blue"]

def draw_stuff(screen, gen, snakes):
    font = pg.font.SysFont("timesnewroman", 20, bold=True)
    label = font.render(f"Gen: {gen} Snakes Alive: {snakes}", 1, "red")

    screen.blit(label, (WIDTH / 2 - (label.get_width()/2),
                (SCREEN_HEIGHT - 30) - label.get_height()/2))


def draw_lines(screen):
    # Lines
    for x in range(COLS + 1):
        pg.draw.line(screen, "red", (x * BLOCKSIZE, 0),
                     (x * BLOCKSIZE, HEIGHT))
    for x in range(ROWS + 1):
        pg.draw.line(screen, "red", (0, x * BLOCKSIZE), (WIDTH, x * BLOCKSIZE))


def draw_grid(screen):
    # Contour
    pg.draw.rect(screen, "red", (0, 0, WIDTH, HEIGHT), 2)
    #draw_lines(screen)


def main(genomes, config):
    global HUMAN, GEN
    screen = pg.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))

    pg.init()
    clock = pg.time.Clock()

    snakes = []
    nets = []
    ge = []

    for _, g in genomes:
        net = neat.nn.FeedForwardNetwork.create(g, config)
        nets.append(net)
        color = random.choice(colors)
        snakes.append(Snake(screen, color))
        g.fitness = 0
        ge.append(g)

    score = 0

    run = True
    while run:
        pg.display.update()
        clock.tick(10)
        screen.fill((30, 30, 30))

        for event in pg.event.get():
            if event.type == pg.QUIT:
                run = False
                pg.quit()
                quit()

        draw_grid(screen)
        draw_stuff(screen, GEN, len(snakes))
        

        for x, snake in enumerate(snakes):
            #ge[x].fitness += 0.1
            snakes[x].move()
            snakes[x].render()
            snakes[x].threeWaystoDie()
            distance2Food = snakes[x].location.distance_to(snakes[x].food.location)

            blockstoCheck = 8
            dangers = snakes[x].checkDangerOrFood(blockstoCheck)
            straight = dangers[0]
            left = dangers[1]
            right = dangers[2]

            inputs = (
                    #snakes[x].location[0], 
                    #snakes[x].location[1],
                    snakes[x].moving_to[0], 
                    snakes[x].moving_to[1],
                    # snakes[x].hungry, 
                    distance2Food, 
                    snakes[x].location.angle_to(snakes[x].food.location)
                    #snakes[x].food.location[0],
                    #snakes[x].food.location[1],
                    )
            #Adding dangers to inputs (3x8)
            for i in range(blockstoCheck):
                inputs = inputs + (straight[i], left[i], right[i])

            output = nets[x].activate((inputs))

            indexOfMaxValue = output.index(max(output))
            if (snake.moving_to[0], snake.moving_to[1]) == (1, 0):
        
                if indexOfMaxValue == 0:
                    turn = snakes[x].moving_to
                if indexOfMaxValue == 1:
                    turn = pg.math.Vector2(0, -1)
                if indexOfMaxValue == 2:
                    turn = pg.math.Vector2(0, 1)

            if (snake.moving_to[0], snake.moving_to[1]) == (-1, 0):
            
                if indexOfMaxValue == 0:
                    turn = snakes[x].moving_to
                if indexOfMaxValue == 1:
                    turn = pg.math.Vector2(0, 1)
                if indexOfMaxValue == 2:
                    turn = pg.math.Vector2(0, -1)

            if (snake.moving_to[0], snake.moving_to[1]) == (0, 1):
            
                if indexOfMaxValue == 0:
                    turn = snakes[x].moving_to
                if indexOfMaxValue == 1:
                    turn = pg.math.Vector2(1, 0)
                if indexOfMaxValue == 2:
                    turn = pg.math.Vector2(-1, 0)

            if (snake.moving_to[0], snake.moving_to[1]) == (0, -1):
            
                if indexOfMaxValue == 0:
                    turn = snakes[x].moving_to
                if indexOfMaxValue == 1:
                    turn = pg.math.Vector2(-1, 0)
                if indexOfMaxValue == 2:
                    turn = pg.math.Vector2(1, 0)

            snakes[x].moving_to = turn

        
        if len(snakes) <= 0:
            run = False
            GEN += 1
            break

    
        for x, snake in enumerate(snakes):
            if snakes[x].location.distance_to(snakes[x].food.location) <= 1.0:
                snake.length += 1
                score += 10
                snakes[x].food.reset(snakes[x].positions)
                snakes[x].hungry = -snake.length 
                ge[x].fitness += 10

            #IF snake dies
            if snakes[x].alive == False:
                ge[x].fitness -= 1
                snakes.pop(x)
                nets.pop(x)
                ge.pop(x)


def run_neat(config):
    #pop = neat.Checkpointer.restore_checkpoint('neat-checkpoint-27')
    pop = neat.Population(config)
    pop.add_reporter(neat.StdOutReporter(True))
    stats = neat.StatisticsReporter()
    pop.add_reporter(stats)
    pop.add_reporter(neat.Checkpointer(10))

    winner = pop.run(main, 50)

if __name__ == "__main__":
    
    pg.display.set_caption('Snake')
    pg.font.init()

    local_dir = os.path.dirname(__file__)
    config_path = os.path.join(local_dir, "config.txt")
    config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, 
            neat.DefaultSpeciesSet, neat.DefaultStagnation, config_path)

    run_neat(config)

Here's also the config file for the neat algorithm

[NEAT]
fitness_criterion     = mean
fitness_threshold     = 400
pop_size              = 250
reset_on_extinction   = False

[DefaultStagnation]
species_fitness_func = max
max_stagnation       = 20
species_elitism      = 2

[DefaultReproduction]
elitism            = 2
survival_threshold = 0.2

[DefaultGenome]
# node activation options
activation_default      = tanh
activation_mutate_rate  = 1.0
activation_options      = tanh

# node aggregation options
aggregation_default     = sum
aggregation_mutate_rate = 0.0
aggregation_options     = sum

# node bias options
bias_init_mean          = 3.0
bias_init_stdev         = 1.0
bias_max_value          = 30.0
bias_min_value          = -30.0
bias_mutate_power       = 0.5
bias_mutate_rate        = 0.7
bias_replace_rate       = 0.1

# genome compatibility options
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient   = 0.5

# connection add/remove rates
conn_add_prob           = 0.5
conn_delete_prob        = 0.5

# connection enable options
enabled_default         = True
enabled_mutate_rate     = 0.04

feed_forward            = True
initial_connection      = full_direct

# node add/remove rates
node_add_prob           = 0.2
node_delete_prob        = 0.2

# network parameters
num_hidden              = 3
num_inputs              = 28
num_outputs             = 3

# node response options
response_init_mean      = 1.0
response_init_stdev     = 0.0
response_max_value      = 30.0
response_min_value      = -30.0
response_mutate_power   = 0.0
response_mutate_rate    = 0.0
response_replace_rate   = 0.0

# connection weight options
weight_init_mean        = 0.0
weight_init_stdev       = 1.0
weight_max_value        = 30
weight_min_value        = -30
weight_mutate_power     = 0.5
weight_mutate_rate      = 0.8
weight_replace_rate     = 0.1

[DefaultSpeciesSet]
compatibility_threshold = 3.0

I would appreciate any advice u can give me, since i'm stuck as hell.

I've tried with different inputs and outputs, but nothing seems to work, which makes me think that there's an error somewhere else. Here's some of the outputs examples:

[0.9999999999999445, -0.999999999999569, 1.0]
[-0.9999999999497647, 1.0, 1.0]
[-1.0, 1.0, -1.0]
[-1.0, 1.0, 1.0]
[-1.0, 0.8234808233113396, 1.0]
[-0.9999999999991865, -0.9999999999999998, 1.0]
0

There are 0 answers