
transitions_collate_fn expects next_obs from TransitionsMinimal #860

@namheegordonkim

Description


Bug description

result["next_obs"] = stack_maybe_dictobs([sample["next_obs"] for sample in batch])

TransitionsMinimal has no next_obs field, but transitions_collate_fn, which is called during data loading, unconditionally reads next_obs from every sample. This makes TransitionsMinimal unusable with, e.g., BC.
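One way to avoid the failure is for the collate function to stack only the fields each sample actually carries. The following is a minimal sketch under stated assumptions: collate_transitions_tolerant is a hypothetical name, it returns NumPy arrays rather than torch tensors, it does not handle dictionary observations, and it is not imitation's actual fix.

```python
from typing import Any, Dict, Mapping, Sequence

import numpy as np


def collate_transitions_tolerant(
    batch: Sequence[Mapping[str, Any]],
) -> Dict[str, Any]:
    """Hypothetical collate function that only stacks the fields present.

    Minimal samples carry obs/acts/infos; full transitions also carry
    next_obs/dones. Unlike the line quoted in the issue, next_obs is only
    read when the samples actually have it.
    """
    result: Dict[str, Any] = {}
    # Infos are arbitrary objects, so keep them as a plain list.
    result["infos"] = [sample["infos"] for sample in batch]
    for key in ("obs", "acts", "next_obs", "dones"):
        if key in batch[0]:
            # np.stack adds a leading batch dimension (array observations only).
            result[key] = np.stack([np.asarray(sample[key]) for sample in batch])
    return result
```

With minimal samples the result simply has no next_obs key, so a consumer that only needs observations and actions would keep working.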

Steps to reproduce

import gym
import d4rl  # registers the D4RL environments with gym
from gymnasium.spaces import Box
import numpy as np

from imitation.algorithms.bc import BC
from imitation.data.types import TransitionsMinimal

# Load the offline D4RL dataset as flat arrays.
env = gym.make("halfcheetah-expert-v2")
dataset = env.get_dataset()

# Build transitions without next_obs or dones.
transitions = TransitionsMinimal(obs=dataset["observations"], acts=dataset["actions"], infos=dataset["infos/qpos"])

# Convert the legacy gym spaces to gymnasium Box spaces.
observation_space = Box(env.observation_space.low, env.observation_space.high, dtype=float)
action_space = Box(env.action_space.low, env.action_space.high, dtype=float)
bc = BC(
    observation_space=observation_space,
    action_space=action_space,
    demonstrations=transitions,
    rng=np.random.default_rng(0),
)

# Expected to fail inside transitions_collate_fn with KeyError: 'next_obs'.
bc.train(n_epochs=1)
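Until the collate function is fixed, a possible workaround is to supply the missing fields and use Transitions instead of TransitionsMinimal. This sketch assumes the rows are in time order and that a boolean array marks the last step of each episode; fill_next_obs is a hypothetical helper, not part of imitation. Steps with no recorded successor reuse their own observation as a placeholder next_obs.

```python
import numpy as np


def fill_next_obs(obs: np.ndarray, episode_ends: np.ndarray):
    """Hypothetical helper: derive next_obs and dones from a flat trajectory log.

    Assumes rows are in time order and episode_ends[i] is True when step i is
    the last step of its episode. Steps without a recorded successor get their
    own observation as a placeholder next_obs.
    """
    obs = np.asarray(obs)
    dones = np.asarray(episode_ends, dtype=bool).copy()
    # The final row never has a successor, so always mark it as an episode end.
    dones[-1] = True
    next_obs = obs.copy()
    # For every step that does not end an episode, the successor is the next row.
    idx = np.flatnonzero(~dones)
    next_obs[idx] = obs[idx + 1]
    return next_obs, dones
```

The results, together with obs, acts and infos, could then be passed to imitation.data.types.Transitions(obs=..., acts=..., infos=..., next_obs=..., dones=...). For D4RL, episode_ends might be dataset["terminals"] | dataset["timeouts"], assuming both keys are present. Since BC trains on observations and actions only, the placeholder next_obs values should not affect the learned policy.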

Environment

Metadata

Assignees: no one assigned
Labels: bug (Something isn't working)
Milestone: none
Development: no branches or pull requests