UnityToGymWrapper does not receive proper raycast observations

29 views Asked by At

I'm trying to train the "Sorter" Unity example environment using the stable-baselines3 PPO algorithm. With the code below, the observations returned by UnityToGymWrapper differ from the ones ML-Agents itself provides. Can anyone help me get all of the raycast observations passed through UnityToGymWrapper correctly?

def make_mla_sb3_env(config: LimitedConfig, **kwargs: Any) -> Env:
    """Launch one ML-Agents Unity environment wrapped for stable-baselines3.

    The Unity build ``config.env_path_or_name`` is started with an
    ``EngineConfigurationChannel`` (configured from ``config.engine_config``)
    appended to any caller-supplied ``side_channels``. The Unity env is
    wrapped in ``UnityToGymWrapper`` and then in ``observation_lambda_v0``,
    which turns a ``gym.spaces.Tuple`` observation space into a
    ``gym.spaces.Dict`` (or unwraps a one-element tuple).

    Args:
        config: Launch settings. This function reads ``engine_config``,
            ``env_registry``, ``base_port``, ``base_seed``,
            ``env_path_or_name``, ``visual_obs`` and ``allow_multiple_obs``.
        **kwargs: Forwarded to ``_unity_env_from_path_or_registry``
            (e.g. ``no_graphics``, ``side_channels``). Not mutated.

    Returns:
        A single wrapped environment launched with worker id 1.
        NOTE(review): this is not vectorised and ``config.num_env`` is unused.
        Wrap several ``create_env`` factories in ``SubprocVecEnv`` to get a
        real ``VecEnv``.
    """

    def handle_obs(obs, space):
        """Convert a tuple observation to match ``handle_obs_space``."""
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return obs[0]
            # Stable-baselines3 handles spaces.Dict but not spaces.Tuple, so
            # re-key the tuple elements as "0", "1", ...
            return {str(i): v for i, v in enumerate(obs)}
        return obs

    def handle_obs_space(space):
        """Convert a Tuple space into the Dict space produced by ``handle_obs``."""
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return space[0]
            # Same "0", "1", ... keys as handle_obs so observations fit the space.
            return gym.spaces.Dict({str(i): s for i, s in enumerate(space)})
        return space

    def create_env(env_name: str, worker_id: int) -> Callable[[], Env]:
        """Return a zero-argument factory that launches and wraps one env."""

        def _f() -> Env:
            engine_configuration_channel = EngineConfigurationChannel()
            engine_configuration_channel.set_configuration(config.engine_config)
            # Build a fresh kwargs dict per launch instead of mutating the
            # enclosing ``kwargs``. Otherwise each call of this factory would
            # append yet another channel to the shared ``side_channels`` entry.
            env_kwargs = dict(kwargs)
            env_kwargs["side_channels"] = [
                *kwargs.get("side_channels", []),
                engine_configuration_channel,
            ]

            unity_env = _unity_env_from_path_or_registry(
                env=env_name,
                registry=config.env_registry,
                worker_id=worker_id,
                base_port=config.base_port,
                seed=config.base_seed + worker_id,
                **env_kwargs,
            )
            gym_env = UnityToGymWrapper(
                unity_env=unity_env,
                uint8_visual=config.visual_obs,
                allow_multiple_obs=config.allow_multiple_obs,
            )
            return observation_lambda_v0(gym_env, handle_obs, handle_obs_space)

        return _f

    # Only a single worker is launched, so invoke its factory directly.
    return create_env(config.env_path_or_name, worker_id=1)()

# Launch the Sorter build as a single wrapped environment. The ".app" suffix
# suggests a macOS build; a registry name or any other build path also works.
app = "Sorter"
env = make_mla_sb3_env(
    no_graphics=False,
    config=LimitedConfig(
        env_path_or_name=app + ".app",
        num_env=NUM_ENVS,
        # True was tried as well; it made no difference to the observations.
        allow_multiple_obs=False,
        base_port=6006,
        base_seed=42,
    ),
)

The UnityToGymWrapper observation for a single instance looks like this:

[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00 -0.0000000e+00  9.9999994e-01  2.7002692e+00
  -5.9604645e-06 -6.4622855e+00 -5.2537459e-01  4.2033527e-02
   9.8480719e-01  1.7365181e-01 -5.3345096e-01  1.1550533e-02
   8.6602354e-01  5.0000316e-01]]

The ML-Agents observation for a single instance is:

 array([[ 0.        ,  1.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -1.1345062 ,  0.26473612,  0.        ],
         [ 0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.49297386,  1.0939786 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -0.10950609,  1.2897362 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.20723641, -0.7100969 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -1.084339  ,  0.58147883,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  1.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -1.0843391 , -0.05200641,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -0.4262485 ,  1.2395691 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -0.7119859 , -0.5645063 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.86532694, -0.05200632,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -0.10950617, -0.7602638 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  1.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -0.9387484 , -0.33774394,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.20723638,  1.2395691 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -0.4262485 , -0.7100969 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          -0.7119862 ,  1.0939785 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.71973634,  0.8672161 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  1.        ,  0.        ,  0.        ,  0.        ,
           0.71973646, -0.33774373,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
           0.8653268 ,  0.5814786 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
          -0.9387487 ,  0.8672159 ,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
           0.915494  ,  0.26473612,  0.        ],
         [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.4929738 , -0.5645063 ,  0.        ]], dtype=float32),
      array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
          7.3468129e-09,  0.0000000e+00,  1.0000000e+00, -6.2222075e-01,
          2.9802322e-08, -8.0573219e-01], dtype=float32),
     array([-0.31923848, -0.23316126, -0.64279026, -0.76604235, -0.311114  ,
         -0.24242917, -0.8660271 , -0.49999726], dtype=float32)

I tried different parameter settings, e.g. allow_multiple_obs=True and allow_multiple_obs=False, but neither worked.

1

There is 1 answer

0
Maisa-ASM On

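Fixed it by editing miniconda3/envs/mlagents/lib/python3.10/site-packages/mlagents_envs/envs/unity_gym_env.py. The wrapper's `_get_vector_obs` and `_get_vec_obs_size` only handled 1-D (vector) observations, so the 2-D (20 × 23) observation was silently dropped. The 18 values the wrapper returned are just the 10-element and 8-element vectors concatenated. The patch below flattens the 2-D observation and counts it in the vector observation size. Note that edits inside site-packages are lost whenever the package is reinstalled or upgraded.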

 def _get_vector_obs(
        self, step_result: Union[DecisionSteps, TerminalSteps]
    ) -> np.ndarray:
        result: List[np.ndarray] = []  
        for obs in step_result.obs:
            #print('[***] step_result ', obs)   
            if len(obs.shape) == 2:
                #print('[***] step_result ', obs.shape)
                result.append(obs)
            if len(obs.shape) == 3:
                #print('shape ', np.expand_dims(obs[0].flatten(),axis=0).shape)
                result.append(np.expand_dims(obs[0].flatten(),axis=0))
        return np.concatenate(result, axis=1)
        
    def _get_vec_obs_size(self) -> int:
        result = 0
        for obs_spec in self.group_spec.observation_specs:
            #print('obs_spec in gym ', obs_spec.shape)
            if len(obs_spec.shape) == 2 : # added for the sorter game by maisa
               result += obs_spec.shape[0] * obs_spec.shape[1]
            if len(obs_spec.shape) == 1:
                result += obs_spec.shape[0]
        return result