# Source code for spaic.IO.Environment

# -*- coding: utf-8 -*-
"""
Created on 2020/8/12
@project: SPAIC
@filename: Environment
@author: Hong Chaofei
@contact: hongchf@gmail.com

@description:
Defines the environment-interaction module for reinforcement learning.
"""
from abc import ABC, abstractmethod
from .utils import RGBtoGray, GraytoBinary, reshape
# import gym
import numpy as np

'''
# Example: initialize and run the CartPole-v0 environment
import gym
environment = gym.make('CartPole-v0')
for i_episode in range(20):
    observation = environment.reset()
    for t in range(100):
        environment.render()
        print(observation)
        action = environment.action_space.sample()
        observation, reward, done, info = environment.step(action)
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            break
environment.close()
'''

class BaseEnvironment(ABC):
    """
    Abstract environment class.

    Declares the interface an environment wrapper has to provide:
    ``step``, ``reset``, ``render``, ``seed`` and ``close``. All five are
    abstract, so a subclass must override every one of them before it can
    be instantiated.
    """

    def __init__(self):
        """Perform no initialization; subclasses set up their own state."""
        pass

    @abstractmethod
    def step(self, action: int):
        """
        Abstract method for ``step()``.

        Args:
            action (int): action to take in environment.
        """
        pass

    @abstractmethod
    def reset(self):
        """
        Abstract method for ``reset()``.
        """
        pass

    @abstractmethod
    def render(self):
        """
        Abstract method for ``render()``.
        """
        pass

    @abstractmethod
    def seed(self, seed):
        """
        Abstract method for ``seed()``.

        Args:
            seed: seed value handed to the environment.
        """
        pass

    @abstractmethod
    def close(self):
        """
        Abstract method for ``close()``.
        """
        pass
class GymEnvironment(BaseEnvironment):
    """
    Wrapper around the OpenAI ``gym`` environments.

    Observations returned by ``reset()`` and ``step()`` are run through the
    same preprocessing pipeline (optional grayscale conversion, optional
    binarization, optional reshape, optional flattening) before they are
    handed back to the caller.
    """

    def __init__(self, name: str, **kwargs):
        """
        Initializes the environment wrapper.

        The observation returned by the ``gym`` environment must expose
        ``.shape`` and ``.flatten()`` (e.g. a ``numpy`` array).

        Args:
            name (str): The name of an OpenAI ``gym`` environment.

        Keyword Args:
            shape: Target shape for observations; ``None`` (default) keeps
                the shape produced by the earlier preprocessing steps.
            binary (bool): Whether to pass observations through
                ``GraytoBinary``. Defaults to ``False``.
            gray (bool): Whether to pass observations with three or more
                dimensions through ``RGBtoGray``. Defaults to ``True``.
            flatten (bool): Whether to flatten observations with two or
                more dimensions into 1D. Defaults to ``True``.
            max_prob (float): Maximum spiking probability, in ``(0, 1]``.
                Defaults to ``1.0``.
            clip_rewards (bool): Whether or not to use ``np.sign`` of
                rewards. Defaults to ``True``.

        Raises:
            AssertionError: If ``max_prob`` is not in ``(0, 1]``.
        """
        import gym

        # Validate before creating the gym environment so that a bad
        # argument does not leave an unclosed environment behind.
        max_prob = kwargs.get('max_prob', 1.0)
        assert (
            0.0 < max_prob <= 1.0
        ), "Maximum spiking probability must be in (0, 1]."

        self.name = name
        # NOTE: the attribute name is misspelled ("environmet"); it is kept
        # as-is because external code may already reference it.
        self.environmet = gym.make(name)
        self.action_space = self.environmet.action_space
        self.action_num = self.action_space.n

        # Observation preprocessing options.
        self.shape = kwargs.get('shape', None)
        self.binary = kwargs.get('binary', False)
        # Read the 'gray' key: the previous code read 'binary' here, which
        # made it impossible to configure grayscale conversion on its own.
        self.gray = kwargs.get('gray', True)
        self.flatten = kwargs.get('flatten', True)

        # Keyword arguments.
        self.max_prob = max_prob
        self.clip_rewards = kwargs.get('clip_rewards', True)

        # Per-episode state, filled in by ``reset()`` / ``step()``.
        self.episode_step_count = 0
        self.obs = None
        self.reward = None
        self.done = None

    def _preprocess(self, obs):
        """
        Apply the configured conversions to a single observation.

        Shared by ``reset()`` and ``step()`` so both return observations in
        the same format.

        Args:
            obs: Observation as returned by the ``gym`` environment.

        Returns:
            The converted observation.
        """
        # Observations with 3 or more dimensions are treated as color
        # images and converted with ``RGBtoGray`` when ``gray`` is set.
        if len(obs.shape) >= 3 and self.gray:
            obs = RGBtoGray(obs)

        if self.binary:
            obs = GraytoBinary(obs)

        if self.shape is not None:
            # ``obs.shape`` is a tuple; normalise ``self.shape`` so that a
            # list-valued shape does not compare unequal on every call.
            self.shape = tuple(self.shape)
            if self.shape != obs.shape:
                obs = reshape(obs, self.shape)

        # Flatten
        if len(obs.shape) >= 2 and self.flatten:
            obs = obs.flatten()

        return obs

    def step(self, action):
        """
        Wrapper around the OpenAI ``gym`` environment ``step()`` function.

        Args:
            action (int): Action to take in the environment.

        Returns:
            Observation, reward, done flag, and information dictionary.
            The observation is the preprocessed one; it is also stored in
            the information dictionary under the key ``'gym_obs'``.
        """
        # Call gym's environment step function.
        self.obs, self.reward, self.done, info = self.environmet.step(action)

        if self.clip_rewards:
            self.reward = np.sign(self.reward)

        # 1D observations are left as they are. 2D observations are mono
        # images and are flattened into 1D. 3D observations are color
        # images that are converted to grayscale and then flattened into 1D
        # (each step subject to the corresponding option being enabled).
        self.obs = self._preprocess(self.obs)

        # Expose the observation through the info dict for display. Note
        # that this is the preprocessed observation, not the raw one.
        info['gym_obs'] = self.obs

        self.episode_step_count += 1

        # Return converted observations and other information.
        return self.obs, self.reward, self.done, info

    def reset(self):
        """
        Wrapper around the OpenAI ``gym`` environment ``reset()`` function.

        Also resets the per-episode step counter to zero.

        :return: Preprocessed observation from the environment.
        """
        # Call gym's environment reset function.
        self.obs = self._preprocess(self.environmet.reset())

        self.episode_step_count = 0

        return self.obs

    def render(self, mode='human'):
        """
        Wrapper around the OpenAI ``gym`` environment ``render()`` function.

        Args:
            mode (str): Render mode forwarded to the ``gym`` environment.
                Defaults to ``'human'`` so the method can also be called
                without arguments, as declared by ``BaseEnvironment``.

        Returns:
            Whatever the underlying environment's ``render()`` returns.
        """
        return self.environmet.render(mode)

    def seed(self, seed):
        """
        Wrapper around the OpenAI ``gym`` environment ``seed()`` function.

        Args:
            seed: Seed value forwarded to the ``gym`` environment.
        """
        self.environmet.seed(seed)

    def close(self):
        """
        Wrapper around the OpenAI ``gym`` environment ``close()`` function.
        """
        self.environmet.close()