Policy Gradient

class PolicyGradient(model, lr)

Bases: parl.core.paddle.algorithm.Algorithm

__init__(model, lr)

Policy gradient algorithm.

Parameters:
  • model (parl.Model) – model that defines the forward network of the policy.
  • lr (float) – learning rate used when updating the model.
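The docstrings do not state the exact objective. The block below is a sketch that assumes the standard REINFORCE (likelihood-ratio) estimator and plain gradient descent. In it, s_i, a_i and R_i are the i-th rows of the `learn` arguments `obs`, `action` and `reward`, N is batch_size, and π_θ is the action distribution that `predict` returns:

```latex
\mathcal{L}(\theta) = -\frac{1}{N}\sum_{i=1}^{N} R_i \,\log \pi_\theta(a_i \mid s_i),
\qquad
\theta \leftarrow \theta - \mathrm{lr}\,\nabla_\theta \mathcal{L}(\theta)
```

Minimizing this scalar loss raises the log-probability of actions in proportion to their R_i. The scalar result is consistent with the shape (1) that `learn` returns. The optimizer PARL actually applies with `lr` is not documented here.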

learn(obs, action, reward)

Update the model with the policy gradient algorithm.

Parameters:
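  • obs (paddle tensor) – batch of observations, shape of (batch_size, obs_dim)
  • action (paddle tensor) – action taken for each observation, shape of (batch_size, 1)
  • reward (paddle tensor) – reward associated with each (obs, action) pair, shape of (batch_size, 1)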
Returns:

loss (paddle tensor) – shape of (1)
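The following is a minimal, framework-free sketch of one `learn`-style update under the same REINFORCE assumption. It makes four substitutions, none of which are part of PARL:

* A hypothetical linear-softmax policy (`weights` of shape (action_dim, obs_dim)) stands in for the `parl.Model`.
* Plain gradient descent stands in for the real optimizer.
* The (batch_size, 1) `action` and `reward` columns are flattened into Python lists.
* The hypothetical helper `discounted_returns` assumes a common convention: callers pass per-step discounted returns as `reward` rather than raw one-step rewards.

```python
import math


def discounted_returns(rewards, gamma=0.99):
    """Hypothetical helper: reward-to-go G_t = r_t + gamma * G_{t+1} for one episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def policy_probs(weights, obs):
    """Hypothetical linear-softmax policy: pi(a|s) = softmax(weights @ s)[a]."""
    logits = [sum(w * o for w, o in zip(row, obs)) for row in weights]
    m = max(logits)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(weights, obs_batch, action_batch, reward_batch, lr):
    """One REINFORCE update, applied to `weights` in place.

    Returns the batch-mean loss computed before the update:
        loss = mean_i(-log pi(a_i | s_i) * R_i)
    For this policy, d(-log pi(a|s)) / d weights[k][j] = (pi(k|s) - [k == a]) * s[j].
    """
    n = len(obs_batch)
    act_dim, obs_dim = len(weights), len(weights[0])
    grad = [[0.0] * obs_dim for _ in range(act_dim)]
    loss = 0.0
    for obs, action, ret in zip(obs_batch, action_batch, reward_batch):
        probs = policy_probs(weights, obs)
        loss -= math.log(probs[action]) * ret
        for k in range(act_dim):
            coeff = (probs[k] - (1.0 if k == action else 0.0)) * ret
            for j in range(obs_dim):
                grad[k][j] += coeff * obs[j]
    # Plain gradient descent on the batch-mean loss (assumed; not PARL's optimizer).
    for k in range(act_dim):
        for j in range(obs_dim):
            weights[k][j] -= lr * grad[k][j] / n
    return loss / n
```

Consider all-zero weights and a single sample (obs `[1.0, 0.0]`, action `0`, return `1.0`). The loss is log 2, and one step with `lr=0.1` raises the probability of action 0 above 0.5. A larger return would produce a proportionally larger push.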

predict(obs)

Predict the probability of each action.

Parameters: obs (paddle tensor) – a single observation, shape of (obs_dim,)
Returns: prob (paddle tensor) – probability of each action, shape of (action_dim,)
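The sketch below shows what `predict` returns, again assuming a hypothetical linear-softmax policy in place of the `parl.Model`. One observation of length obs_dim maps to a probability vector of length action_dim. The `sample_action` and `greedy_action` helpers are illustrative, not PARL API. They show the two usual ways such a vector is consumed: sampling during training and taking the most probable action during evaluation.

```python
import math
import random


def predict(weights, obs):
    """Hypothetical policy: softmax over linear logits, returns a list of length action_dim."""
    logits = [sum(w * o for w, o in zip(row, obs)) for row in weights]
    m = max(logits)  # stabilizes exp() for large logits
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(probs, rng=random):
    """Draw action index a with probability probs[a] (typical during training)."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def greedy_action(probs):
    """Return the index of the most probable action (typical during evaluation)."""
    return max(range(len(probs)), key=lambda a: probs[a])


# Example: two actions, two-dimensional observation.
probs = predict([[1.0, 0.0], [0.0, 1.0]], [2.0, 0.0])
action = sample_action(probs, random.Random(0))
```

Sampling keeps the policy stochastic, so the update in `learn` still sees less likely actions. The greedy choice is deterministic.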