diff --git a/configs/env/default.yaml b/configs/env/default.yaml index aa09a8cb..1debdb1f 100644 --- a/configs/env/default.yaml +++ b/configs/env/default.yaml @@ -3,4 +3,4 @@ name: tsp generator_params: num_loc: 20 - loc_distribution: uniform + loc_distribution: uniform \ No newline at end of file diff --git a/configs/env/fjsp.yaml b/configs/env/fjsp/10j-5m.yaml similarity index 100% rename from configs/env/fjsp.yaml rename to configs/env/fjsp/10j-5m.yaml diff --git a/configs/env/fjsp/15j-10m.yaml b/configs/env/fjsp/15j-10m.yaml new file mode 100644 index 00000000..7df35842 --- /dev/null +++ b/configs/env/fjsp/15j-10m.yaml @@ -0,0 +1,13 @@ +_target_: rl4co.envs.FJSPEnv +name: fjsp + +generator_params: + num_jobs: 15 + num_machines: 10 + min_ops_per_job: 8 + max_ops_per_job: 12 + min_processing_time: 1 + max_processing_time: 20 + min_eligible_ma_per_op: 1 + +data_dir: ${paths.root_dir}/data/fjsp diff --git a/configs/env/fjsp/20j-10m.yaml b/configs/env/fjsp/20j-10m.yaml new file mode 100644 index 00000000..4574f22a --- /dev/null +++ b/configs/env/fjsp/20j-10m.yaml @@ -0,0 +1,13 @@ +_target_: rl4co.envs.FJSPEnv +name: fjsp + +generator_params: + num_jobs: 20 + num_machines: 10 + min_ops_per_job: 8 + max_ops_per_job: 12 + min_processing_time: 1 + max_processing_time: 20 + min_eligible_ma_per_op: 1 + +data_dir: ${paths.root_dir}/data/fjsp diff --git a/configs/env/fjsp/20j-5m.yaml b/configs/env/fjsp/20j-5m.yaml new file mode 100644 index 00000000..04e7f8d7 --- /dev/null +++ b/configs/env/fjsp/20j-5m.yaml @@ -0,0 +1,13 @@ +_target_: rl4co.envs.FJSPEnv +name: fjsp + +generator_params: + num_jobs: 20 + num_machines: 5 + min_ops_per_job: 4 + max_ops_per_job: 6 + min_processing_time: 1 + max_processing_time: 20 + min_eligible_ma_per_op: 1 + +data_dir: ${paths.root_dir}/data/fjsp diff --git a/configs/env/jssp/10j-10m.yaml b/configs/env/jssp/10j-10m.yaml new file mode 100644 index 00000000..d7e13964 --- /dev/null +++ b/configs/env/jssp/10j-10m.yaml @@ -0,0 +1,11 @@ +_target_: 
rl4co.envs.JSSPEnv +name: jssp + +generator_params: + num_jobs: 10 + num_machines: 10 + min_processing_time: 1 + max_processing_time: 99 + +data_dir: ${paths.root_dir}/data/jssp/taillard +test_file: ${env.generator_params.num_jobs}j_${env.generator_params.num_machines}m \ No newline at end of file diff --git a/configs/env/jssp/15j-15m.yaml b/configs/env/jssp/15j-15m.yaml new file mode 100644 index 00000000..16d20dc2 --- /dev/null +++ b/configs/env/jssp/15j-15m.yaml @@ -0,0 +1,11 @@ +_target_: rl4co.envs.JSSPEnv +name: jssp + +generator_params: + num_jobs: 15 + num_machines: 15 + min_processing_time: 1 + max_processing_time: 99 + +data_dir: ${paths.root_dir}/data/jssp/taillard +test_file: ${env.generator_params.num_jobs}j_${env.generator_params.num_machines}m \ No newline at end of file diff --git a/configs/env/jssp/20j-20m.yaml b/configs/env/jssp/20j-20m.yaml new file mode 100644 index 00000000..b3373d85 --- /dev/null +++ b/configs/env/jssp/20j-20m.yaml @@ -0,0 +1,11 @@ +_target_: rl4co.envs.JSSPEnv +name: jssp + +generator_params: + num_jobs: 20 + num_machines: 20 + min_processing_time: 1 + max_processing_time: 99 + +data_dir: ${paths.root_dir}/data/jssp/taillard +test_file: ${env.generator_params.num_jobs}j_${env.generator_params.num_machines}m \ No newline at end of file diff --git a/configs/env/jssp/6j-6m.yaml b/configs/env/jssp/6j-6m.yaml new file mode 100644 index 00000000..d7cc4170 --- /dev/null +++ b/configs/env/jssp/6j-6m.yaml @@ -0,0 +1,11 @@ +_target_: rl4co.envs.JSSPEnv +name: jssp + +generator_params: + num_jobs: 6 + num_machines: 6 + min_processing_time: 1 + max_processing_time: 99 + +data_dir: ${paths.root_dir}/data/jssp/taillard +test_file: ${env.generator_params.num_jobs}j_${env.generator_params.num_machines}m \ No newline at end of file diff --git a/configs/experiment/routing/tsp-stepwise-ppo.yaml b/configs/experiment/routing/tsp-stepwise-ppo.yaml new file mode 100644 index 00000000..678223ec --- /dev/null +++ 
b/configs/experiment/routing/tsp-stepwise-ppo.yaml @@ -0,0 +1,57 @@ +# @package _global_ + +defaults: + - override /model: l2d.yaml + - override /callbacks: default.yaml + - override /trainer: default.yaml + - override /logger: wandb.yaml + +env: + _target_: rl4co.envs.TSPEnv4PPO + generator_params: + num_loc: 20 + +logger: + wandb: + project: "rl4co" + tags: ["am-stepwise-ppo", "${env.name}"] + group: ${env.name}${env.generator_params.num_loc} + name: ppo-${env.name}${env.generator_params.num_loc} + +trainer: + max_epochs: 10 + precision: 32-true + +embed_dim: 256 +num_heads: 8 +model: + _target_: rl4co.models.StepwisePPO + policy: + _target_: rl4co.models.L2DPolicy4PPO + decoder: + _target_: rl4co.models.zoo.l2d.decoder.L2DDecoder + env_name: ${env.name} + embed_dim: ${embed_dim} + feature_extractor: + _target_: rl4co.models.zoo.am.encoder.AttentionModelEncoder + embed_dim: ${embed_dim} + num_heads: ${num_heads} + num_layers: 4 + normalization: "batch" + env_name: "tsp" + actor: + _target_: rl4co.models.zoo.l2d.decoder.AttnActor + embed_dim: ${embed_dim} + num_heads: ${num_heads} + env_name: ${env.name} + embed_dim: ${embed_dim} + env_name: ${env.name} + het_emb: False + batch_size: 512 + mini_batch_size: 512 + train_data_size: 20000 + val_data_size: 1_000 + test_data_size: 1_000 + reward_scale: scale + optimizer_kwargs: + lr: 1e-4 \ No newline at end of file diff --git a/configs/experiment/scheduling/am-pomo.yaml b/configs/experiment/scheduling/am-pomo.yaml new file mode 100644 index 00000000..a3d2cde7 --- /dev/null +++ b/configs/experiment/scheduling/am-pomo.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +defaults: + - scheduling/base + +logger: + wandb: + tags: ["am-pomo", "${env.name}"] + name: "am-pomo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m" + +model: + _target_: rl4co.models.POMO + policy: + _target_: rl4co.models.L2DAttnPolicy + env_name: ${env.name} + scaling_factor: ${scaling_factor} + batch_size: 64 + 
num_starts: 10 + num_augment: 0 + baseline: "shared" + metrics: + val: ["reward", "max_reward"] + test: ${model.metrics.val} diff --git a/configs/experiment/scheduling/am-ppo.yaml b/configs/experiment/scheduling/am-ppo.yaml new file mode 100644 index 00000000..c5d38eb1 --- /dev/null +++ b/configs/experiment/scheduling/am-ppo.yaml @@ -0,0 +1,56 @@ +# @package _global_ + +defaults: + - scheduling/base + +logger: + wandb: + tags: ["am-ppo", "${env.name}"] + name: "am-ppo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m" + +embed_dim: 256 +num_heads: 8 + +model: + _target_: rl4co.models.StepwisePPO + policy: + _target_: rl4co.models.L2DPolicy4PPO + decoder: + _target_: rl4co.models.zoo.l2d.decoder.L2DDecoder + env_name: ${env.name} + embed_dim: ${embed_dim} + feature_extractor: + _target_: rl4co.models.zoo.matnet.matnet_w_sa.Encoder + embed_dim: ${embed_dim} + num_heads: ${num_heads} + num_layers: 4 + normalization: "batch" + init_embedding: + _target_: rl4co.models.nn.env_embeddings.init.FJSPMatNetInitEmbedding + embed_dim: ${embed_dim} + scaling_factor: ${scaling_factor} + actor: + _target_: rl4co.models.zoo.l2d.decoder.L2DAttnActor + embed_dim: ${embed_dim} + num_heads: ${num_heads} + env_name: ${env.name} + scaling_factor: ${scaling_factor} + stepwise: True + env_name: ${env.name} + embed_dim: ${embed_dim} + scaling_factor: ${scaling_factor} + het_emb: True + batch_size: 128 + val_batch_size: 512 + test_batch_size: 64 + # Song et al use 1000 iterations over batches of 20 = 20_000 + # We train 10 epochs on a set of 2000 instance = 20_000 + train_data_size: 2000 + mini_batch_size: 512 + reward_scale: scale + optimizer_kwargs: + lr: 1e-4 + +env: + stepwise_reward: True + _torchrl_mode: True \ No newline at end of file diff --git a/configs/experiment/scheduling/base.yaml b/configs/experiment/scheduling/base.yaml new file mode 100644 index 00000000..e84f95fd --- /dev/null +++ b/configs/experiment/scheduling/base.yaml @@ -0,0 +1,38 @@ 
+# @package _global_ + +defaults: + - override /model: l2d.yaml + - override /callbacks: default.yaml + - override /trainer: default.yaml + - override /logger: wandb.yaml + +logger: + wandb: + project: "rl4co" + log_model: "all" + group: "${env.name}-${env.generator_params.num_jobs}-${env.generator_params.num_machines}" + tags: ??? + name: ??? + +trainer: + max_epochs: 10 + # NOTE for some reason l2d is extremely sensitive to precision + # ONLY USE 32-true for l2d! + precision: 32-true + +seed: 12345678 + +scaling_factor: 20 + +model: + _target_: ??? + batch_size: ??? + train_data_size: 2_000 + val_data_size: 1_000 + test_data_size: 1_000 + optimizer_kwargs: + lr: 1e-4 + weight_decay: 1e-6 + lr_scheduler: "ExponentialLR" + lr_scheduler_kwargs: + gamma: 0.95 diff --git a/configs/experiment/scheduling/gnn-ppo.yaml b/configs/experiment/scheduling/gnn-ppo.yaml new file mode 100644 index 00000000..d9c04856 --- /dev/null +++ b/configs/experiment/scheduling/gnn-ppo.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +defaults: + - scheduling/base + +logger: + wandb: + tags: ["gnn-ppo", "${env.name}"] + name: "gnn-ppo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m" + +# params from Song et al. 
+model: + _target_: rl4co.models.L2DPPOModel + policy_kwargs: + embed_dim: 128 + num_encoder_layers: 3 + scaling_factor: ${scaling_factor} + max_grad_norm: 1 + ppo_epochs: 3 + het_emb: False + batch_size: 128 + val_batch_size: 512 + test_batch_size: 64 + mini_batch_size: 512 + reward_scale: scale + optimizer_kwargs: + lr: 1e-4 + +trainer: + max_epochs: 10 + + +env: + stepwise_reward: True + _torchrl_mode: True \ No newline at end of file diff --git a/configs/experiment/scheduling/hgnn-pomo.yaml b/configs/experiment/scheduling/hgnn-pomo.yaml new file mode 100644 index 00000000..eb688c03 --- /dev/null +++ b/configs/experiment/scheduling/hgnn-pomo.yaml @@ -0,0 +1,27 @@ +# @package _global_ + +defaults: + - scheduling/base + +logger: + wandb: + tags: ["hgnn-pomo", "${env.name}"] + name: "hgnn-pomo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m" + +model: + _target_: rl4co.models.POMO + policy: + _target_: rl4co.models.L2DPolicy + env_name: ${env.name} + embed_dim: 256 + num_encoder_layers: 3 + stepwise_encoding: False + scaling_factor: ${scaling_factor} + het_emb: True + num_starts: 10 + batch_size: 64 + num_augment: 0 + baseline: "shared" + metrics: + val: ["reward", "max_reward"] + test: ${model.metrics.val} \ No newline at end of file diff --git a/configs/experiment/scheduling/hgnn-ppo.yaml b/configs/experiment/scheduling/hgnn-ppo.yaml new file mode 100644 index 00000000..8e3a62d8 --- /dev/null +++ b/configs/experiment/scheduling/hgnn-ppo.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +defaults: + - scheduling/base + +logger: + wandb: + tags: ["hgnn-ppo", "${env.name}"] + name: "hgnn-ppo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m" + +# params from Song et al. 
+model: + _target_: rl4co.models.L2DPPOModel + policy_kwargs: + embed_dim: 128 + num_encoder_layers: 3 + scaling_factor: ${scaling_factor} + max_grad_norm: 1 + ppo_epochs: 3 + het_emb: True + batch_size: 128 + val_batch_size: 512 + test_batch_size: 64 + mini_batch_size: 512 + reward_scale: scale + optimizer_kwargs: + lr: 1e-4 + +trainer: + max_epochs: 10 + + +env: + stepwise_reward: True + _torchrl_mode: True \ No newline at end of file diff --git a/configs/experiment/scheduling/matnet-pomo.yaml b/configs/experiment/scheduling/matnet-pomo.yaml new file mode 100644 index 00000000..bab68644 --- /dev/null +++ b/configs/experiment/scheduling/matnet-pomo.yaml @@ -0,0 +1,38 @@ +# @package _global_ + +defaults: + - scheduling/base + +logger: + wandb: + tags: ["matnet-pomo", "${env.name}"] + name: "matnet-pomo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m" + +embed_dim: 256 + +model: + _target_: rl4co.models.POMO + policy: + _target_: rl4co.models.L2DPolicy + encoder: + _target_: rl4co.models.zoo.matnet.matnet_w_sa.Encoder + embed_dim: ${embed_dim} + num_heads: 8 + num_layers: 4 + normalization: "batch" + init_embedding: + _target_: rl4co.models.nn.env_embeddings.init.FJSPMatNetInitEmbedding + embed_dim: ${embed_dim} + scaling_factor: ${scaling_factor} + env_name: ${env.name} + embed_dim: ${embed_dim} + stepwise_encoding: False + het_emb: True + scaling_factor: ${scaling_factor} + batch_size: 64 + num_starts: 10 + num_augment: 0 + baseline: "shared" + metrics: + val: ["reward", "max_reward"] + test: ${model.metrics.val} diff --git a/configs/experiment/scheduling/matnet-ppo.yaml b/configs/experiment/scheduling/matnet-ppo.yaml new file mode 100644 index 00000000..f0e30e3b --- /dev/null +++ b/configs/experiment/scheduling/matnet-ppo.yaml @@ -0,0 +1,48 @@ +# @package _global_ + +defaults: + - scheduling/base + +logger: + wandb: + tags: ["matnet-ppo", "${env.name}"] + name: 
"matnet-ppo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m" + +embed_dim: 256 + +model: + _target_: rl4co.models.StepwisePPO + policy: + _target_: rl4co.models.L2DPolicy4PPO + decoder: + _target_: rl4co.models.zoo.l2d.decoder.L2DDecoder + env_name: ${env.name} + embed_dim: ${embed_dim} + het_emb: True + feature_extractor: + _target_: rl4co.models.zoo.matnet.matnet_w_sa.Encoder + embed_dim: ${embed_dim} + num_heads: 8 + num_layers: 4 + normalization: "batch" + init_embedding: + _target_: rl4co.models.nn.env_embeddings.init.FJSPMatNetInitEmbedding + embed_dim: ${embed_dim} + scaling_factor: ${scaling_factor} + env_name: ${env.name} + embed_dim: ${embed_dim} + scaling_factor: ${scaling_factor} + het_emb: True + batch_size: 128 + val_batch_size: 512 + test_batch_size: 64 + # Song et al use 1000 iterations over batches of 20 = 20_000 + # We train 10 epochs on a set of 2000 instance = 20_000 + mini_batch_size: 512 + reward_scale: scale + optimizer_kwargs: + lr: 1e-4 + +env: + stepwise_reward: True + _torchrl_mode: True \ No newline at end of file diff --git a/configs/model/hetgnn.yaml b/configs/model/hetgnn.yaml deleted file mode 100644 index 600ef49f..00000000 --- a/configs/model/hetgnn.yaml +++ /dev/null @@ -1,3 +0,0 @@ -_target_: rl4co.models.HetGNNModel - -baseline: "rollout" \ No newline at end of file diff --git a/configs/model/l2d.yaml b/configs/model/l2d.yaml new file mode 100644 index 00000000..9b93f1ee --- /dev/null +++ b/configs/model/l2d.yaml @@ -0,0 +1 @@ +_target_: rl4co.models.L2DModel \ No newline at end of file diff --git a/configs/model/matnet.yaml b/configs/model/matnet.yaml index 543de5af..6634e4a1 100644 --- a/configs/model/matnet.yaml +++ b/configs/model/matnet.yaml @@ -2,6 +2,6 @@ _target_: rl4co.models.MatNet metrics: train: ["loss", "reward", "max_reward"] - val: ["max_reward"] + val: ["reward", "max_reward"] test: ["max_reward"] log_on_step: True \ No newline at end of file diff --git 
a/examples/other/2-scheduling.ipynb b/examples/other/2-scheduling.ipynb index ee5de19f..4c4c029e 100644 --- a/examples/other/2-scheduling.ipynb +++ b/examples/other/2-scheduling.ipynb @@ -13,9 +13,18 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -29,16 +38,16 @@ "import networkx as nx\n", "import matplotlib.pyplot as plt\n", "from rl4co.envs import FJSPEnv\n", - "from rl4co.models.zoo.hetgnn import HetGNNModel\n", - "from rl4co.models.zoo.hetgnn.policy import HetGNNPolicy\n", - "from rl4co.models.zoo.hetgnn.encoder import HetGNNEncoder\n", - "from rl4co.models.zoo.hetgnn.decoder import HetGNNDecoder\n", + "from rl4co.models.zoo.l2d import L2DModel\n", + "from rl4co.models.zoo.l2d.policy import L2DPolicy\n", + "from rl4co.models.zoo.l2d.decoder import L2DDecoder\n", + "from rl4co.models.nn.graph.hgnn import HetGNNEncoder\n", "from rl4co.utils.trainer import RL4COTrainer" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -56,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -81,12 +90,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -169,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -192,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -220,16 +229,16 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[ 0, 4, 9, 14, 20, 25, 29, 33, 37, 42]])" + "tensor([[ 0, 5, 10, 16, 20, 24, 29, 34, 40, 44]])" ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -241,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -253,7 +262,7 @@ } ], "source": [ - "decoder = HetGNNDecoder(embed_dim=32)\n", + "decoder = L2DDecoder(env_name=env.name, embed_dim=32)\n", "logits, mask = decoder(td, (ma_emb, op_emb), num_starts=0)\n", "# (1 + num_jobs * num_machines)\n", "print(logits.shape)" @@ -261,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -290,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -304,7 +313,7 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -349,9 +358,83 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n", + "/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n", + "/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:551: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.\n", + "Using bfloat16 Automatic Mixed Precision (AMP)\n", + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. 
Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "Missing logger folder: /Users/luttmann/Documents/Diss/Repos/nco/ai4co/rl4co/examples/other/lightning_logs\n", + "val_file not set. Generating dataset instead\n", + "test_file not set. Generating dataset instead\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------------\n", + "0 | env | FJSPEnv | 0 \n", + "1 | policy | L2DPolicy | 15.9 K\n", + "2 | baseline | WarmupBaseline | 15.9 K\n", + "--------------------------------------------\n", + "31.9 K Trainable params\n", + "0 Non-trainable params\n", + "31.9 K Total params\n", + "0.127 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c543880423f84865a05170d16a5aa6fd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00 TensorDict: # Sample penalty penalty = self.penalty_sampler.sample((*batch_size, self.num_loc)) - # Sampel prize + # Take uniform prizes + # Now expectation is 0.5 so expected total prize is n / 2, we want to force to visit approximately half of the nodes + # so the constraint will be that total prize >= (n / 2) / 2 = n / 4 + # equivalently, we divide all prizes by n / 4 and the total prize should be >= 1 deterministic_prize = self.deterministic_prize_sampler.sample( (*batch_size, self.num_loc) ) + + # In the deterministic setting, the stochastic_prize is not used and the deterministic prize is known + # In the stochastic setting, the deterministic prize is the expected prize and is known up front but the + # stochastic prize is only revealed once the node is visited + # Stochastic prize is between (0, 2 * expected_prize) such that E(stochastic prize) = E(deterministic_prize) stochastic_prize = self.stochastic_prize_sampler.sample( (*batch_size, self.num_loc) - ) + ) * deterministic_prize return TensorDict( { diff --git 
a/rl4co/envs/routing/tsp/env.py b/rl4co/envs/routing/tsp/env.py index eb1c3402..0ea92569 100644 --- a/rl4co/envs/routing/tsp/env.py +++ b/rl4co/envs/routing/tsp/env.py @@ -11,7 +11,7 @@ ) from rl4co.envs.common.base import RL4COEnvBase -from rl4co.utils.ops import gather_by_index, get_tour_length +from rl4co.utils.ops import gather_by_index, get_distance, get_tour_length from rl4co.utils.pylogger import get_pylogger from .generator import TSPGenerator @@ -167,5 +167,62 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): ).all(), "Invalid tour" @staticmethod - def render(td: TensorDict, actions: torch.Tensor=None, ax = None): + def render(td: TensorDict, actions: torch.Tensor = None, ax=None): return render(td, actions, ax) + + +class DenseRewardTSPEnv(TSPEnv): + """ + This is an experimental version of the TSPEnv to be used with stepwise PPO. That is + this environment defines a stepwise reward function for the TSP which is the distance added + to the current tour by the given action. 
+ """ + + def __init__( + self, generator: TSPGenerator = None, generator_params: dict = {}, **kwargs + ): + super().__init__( + generator, + generator_params, + check_solution=False, + _torchrl_mode=True, + **kwargs, + ) + + def _step(self, td): + last_node = td["current_node"].clone() + current_node = td["action"] + + first_node = current_node if td["i"].all() == 0 else td["first_node"] + + # # Set not visited to 0 (i.e., we visited the node) + available = td["action_mask"].scatter( + -1, current_node.unsqueeze(-1).expand_as(td["action_mask"]), 0 + ) + + # We are done there are no unvisited locations + done = torch.sum(available, dim=-1) == 0 + + # calc stepwise reward + last_node_loc = gather_by_index(td["locs"], last_node) + curr_node_loc = gather_by_index(td["locs"], current_node) + reward = get_distance(last_node_loc, curr_node_loc)[:, None] + + td.update( + { + "first_node": first_node, + "current_node": current_node, + "i": td["i"] + 1, + "action_mask": available, + "reward": reward, + "done": done, + }, + ) + return td + + def _get_reward(self, td, actions=None) -> TensorDict: + if actions is not None: + # Gather locations in order of tour and return distance between them (i.e., -reward) + locs_ordered = gather_by_index(td["locs"], actions) + return -get_tour_length(locs_ordered) + return -td["reward"] diff --git a/rl4co/envs/scheduling/__init__.py b/rl4co/envs/scheduling/__init__.py index 897ee755..40b5571e 100644 --- a/rl4co/envs/scheduling/__init__.py +++ b/rl4co/envs/scheduling/__init__.py @@ -1,3 +1,5 @@ from rl4co.envs.scheduling.ffsp.env import FFSPEnv from rl4co.envs.scheduling.fjsp.env import FJSPEnv + +# from rl4co.envs.scheduling.jssp.env import JSSPEnv from rl4co.envs.scheduling.smtwtp.env import SMTWTPEnv diff --git a/rl4co/envs/scheduling/ffsp/env.py b/rl4co/envs/scheduling/ffsp/env.py index f3037e18..da8eff24 100644 --- a/rl4co/envs/scheduling/ffsp/env.py +++ b/rl4co/envs/scheduling/ffsp/env.py @@ -13,11 +13,10 @@ 
UnboundedDiscreteTensorSpec, ) -from .generator import FFSPGenerator -from .render import render - from rl4co.envs.common.base import RL4COEnvBase +from .generator import FFSPGenerator + class FFSPEnv(RL4COEnvBase): """Flexible Flow Shop Problem (FFSP) environment. @@ -58,7 +57,7 @@ def __init__( generator_params: dict = {}, **kwargs, ): - super().__init__(**kwargs) + super().__init__(check_solution=False, **kwargs) if generator is None: generator = FFSPGenerator(**generator_params) self.generator = generator @@ -283,9 +282,7 @@ def _reset( # Init index record tensor time_idx = torch.zeros(size=(*batch_size,), dtype=torch.long, device=device) - sub_time_idx = torch.zeros( - size=(*batch_size,), dtype=torch.long, device=device - ) + sub_time_idx = torch.zeros(size=(*batch_size,), dtype=torch.long, device=device) # Scheduling status information schedule = torch.full( @@ -412,7 +409,10 @@ def _make_spec(self, generator: FFSPGenerator): dtype=torch.int64, ), job_duration=UnboundedDiscreteTensorSpec( - shape=(generator.num_job + 1, generator.num_machine * generator.num_stage), + shape=( + generator.num_job + 1, + generator.num_machine * generator.num_stage, + ), dtype=torch.int64, ), shape=(), diff --git a/rl4co/envs/scheduling/fjsp/env.py b/rl4co/envs/scheduling/fjsp/env.py index 4a6a217f..dac1c8b6 100644 --- a/rl4co/envs/scheduling/fjsp/env.py +++ b/rl4co/envs/scheduling/fjsp/env.py @@ -68,6 +68,8 @@ def __init__( generator: FJSPGenerator = None, generator_params: dict = {}, mask_no_ops: bool = True, + check_mask: bool = False, + stepwise_reward: bool = False, **kwargs, ): super().__init__(check_solution=False, **kwargs) @@ -81,6 +83,8 @@ def __init__( self.num_jobs = generator.num_jobs self.n_ops_max = generator.max_ops_per_job * self.num_jobs self.mask_no_ops = mask_no_ops + self.check_mask = check_mask + self.stepwise_reward = stepwise_reward self._make_spec(self.generator) def _decode_graph_structure(self, td: TensorDict): @@ -173,13 +177,14 @@ def _reset(self, 
td: TensorDict = None, batch_size=None) -> TensorDict: }, ) - td_reset.set("lbs", calc_lower_bound(td_reset)) - td_reset.set("is_ready", op_is_ready(td_reset)) td_reset.set("action_mask", self.get_action_mask(td_reset)) + # add additional features to tensordict + td_reset["lbs"] = calc_lower_bound(td_reset) + td_reset = self._get_features(td_reset) return td_reset - def get_action_mask(self, td: TensorDict) -> torch.Tensor: + def _get_job_machine_availability(self, td: TensorDict): batch_size = td.size(0) # (bs, jobs, machines) @@ -200,16 +205,33 @@ def get_action_mask(self, td: TensorDict) -> torch.Tensor: td["proc_times"], td["next_op"].unsqueeze(1), dim=2, squeeze=False ).transpose(1, 2) action_mask.add_(next_ops_proc_times == 0) + return action_mask + + def get_action_mask(self, td: TensorDict) -> torch.Tensor: + # 1 indicates machine or job is unavailable at current time step + action_mask = self._get_job_machine_availability(td) if self.mask_no_ops: - no_op_mask = ~td["done"] + # masking is only allowed if instance is finished + no_op_mask = td["done"] else: - no_op_mask = ~td["job_in_process"].any(1, keepdims=True) & ~td["done"] + # if no job is currently processed and instance is not finished yet, waiting is not allowed + no_op_mask = ( + td["job_in_process"].any(1, keepdims=True) & (~td["done"]) + ) | td["done"] # flatten action mask to correspond with logit shape action_mask = rearrange(action_mask, "bs j m -> bs (j m)") # NOTE: 1 means feasible action, 0 means infeasible action - mask = torch.cat((~no_op_mask, ~action_mask), dim=1) + mask = torch.cat((no_op_mask, ~action_mask), dim=1) + return mask + def _translate_action(self, td): + """This function translates an action into a machine, job tuple.""" + selected_job = td["action"] // self.num_mas + selected_op = td["next_op"].gather(1, selected_job[:, None]).squeeze(1) + selected_machine = td["action"] % self.num_mas + return selected_job, selected_op, selected_machine + def _step(self, td: TensorDict): 
# cloning required to avoid inplace operation which avoids gradient backtracking td = td.clone() @@ -225,14 +247,11 @@ def _step(self, td: TensorDict): if no_op.any(): td, dones = self._transit_to_next_time(no_op, td) + # select only instances that perform a scheduling action td_op = td.masked_select(req_op) - # (#req_op) - selected_job = td_op["action"] // self.num_mas - # (#req_op) - selected_machine = td_op["action"] % self.num_mas - td_op = self._make_step(td_op, selected_job, selected_machine) - + td_op = self._make_step(td_op) + # update the tensordict td[req_op] = td_op # action mask @@ -243,11 +262,26 @@ def _step(self, td: TensorDict): td, dones = self._transit_to_next_time(step_complete, td) td.set("action_mask", self.get_action_mask(td)) step_complete = self._check_step_complete(td, dones) + if self.check_mask: + assert reduce(td["action_mask"], "bs ... -> bs", "any").all() + + if self.stepwise_reward: + # if we require a stepwise reward, the change in the calculated lower bounds could serve as such + lbs = calc_lower_bound(td) + td["reward"] = -(lbs.max(1).values - td["lbs"].max(1).values) + td["lbs"] = lbs + else: + td["lbs"] = calc_lower_bound(td) + + # add additional features to tensordict + td = self._get_features(td) + + return td + def _get_features(self, td): # after we have transitioned to a next time step, we determine which operations are ready td["is_ready"] = op_is_ready(td) - - td["lbs"] = calc_lower_bound(td) + # td["lbs"] = calc_lower_bound(td) return td @@ -259,17 +293,18 @@ def _check_step_complete(td, dones): """ return ~reduce(td["action_mask"], "bs ... 
-> bs", "any") & ~dones - def _make_step(self, td: TensorDict, selected_job, selected_machine) -> TensorDict: + def _make_step(self, td: TensorDict) -> TensorDict: """ Environment transition function """ batch_idx = torch.arange(td.size(0)) - td["job_in_process"][batch_idx, selected_job] = 1 + # 3*(#req_op) + selected_job, selected_op, selected_machine = self._translate_action(td) - # (#req_op) - selected_op = td["next_op"].gather(1, selected_job[:, None]).squeeze(1) + # mark job as being processed + td["job_in_process"][batch_idx, selected_job] = 1 # mark op as schedules td["op_scheduled"][batch_idx, selected_op] = True @@ -285,6 +320,23 @@ def _make_step(self, td: TensorDict, selected_job, selected_machine) -> TensorDi td["ma_assignment"][batch_idx, selected_machine, selected_op] = 1 # update the state of the selected machine td["busy_until"][batch_idx, selected_machine] = td["time"] + proc_time_of_action + # update adjacency matrices (remove edges) + td["proc_times"] = td["proc_times"].scatter( + 2, + selected_op[:, None, None].expand(-1, self.num_mas, 1), + torch.zeros_like(td["proc_times"]), + ) + td["ops_ma_adj"] = td["proc_times"].contiguous().gt(0).to(torch.float32) + td["num_eligible"] = torch.sum(td["ops_ma_adj"], dim=1) + # update the positions of an operation in the job (subtract 1 from each operation of the selected job) + td["ops_sequence_order"] = ( + td["ops_sequence_order"] - gather_by_index(td["job_ops_adj"], selected_job, 1) + ).clip(0) + # some checks + assert torch.allclose( + td["proc_times"].sum(1).gt(0).sum(1), # num ops with eligible machine + (~(td["op_scheduled"] + td["pad_mask"])).sum(1), # num unscheduled ops + ) return td @@ -333,7 +385,15 @@ def _transit_to_next_time(self, step_complete, td: TensorDict) -> TensorDict: return td, td["done"].squeeze(1) def _get_reward(self, td, actions=None) -> TensorDict: - return -td["finish_times"].masked_fill(td["pad_mask"], -torch.inf).max(1).values + if self.stepwise_reward and actions is None: + 
return td["reward"] + else: + assert td[ + "done" + ].all(), "Set stepwise_reward to True if you want reward prior to completion" + return ( + -td["finish_times"].masked_fill(td["pad_mask"], -torch.inf).max(1).values + ) def _make_spec(self, generator: FJSPGenerator): self.observation_spec = CompositeSpec( @@ -422,3 +482,8 @@ def select_start_nodes(self, td: TensorDict, num_starts: int): def get_num_starts(self, td): # NOTE in the paper they use N_s = 100 return 100 + + @staticmethod + def load_data(fpath, batch_size=[]): + g = FJSPFileGenerator(fpath) + return g(batch_size=batch_size) diff --git a/rl4co/envs/scheduling/fjsp/generator.py b/rl4co/envs/scheduling/fjsp/generator.py index f1ae6202..60246a50 100644 --- a/rl4co/envs/scheduling/fjsp/generator.py +++ b/rl4co/envs/scheduling/fjsp/generator.py @@ -198,6 +198,8 @@ def _generate(self, batch_size: List[int]) -> TensorDict: end_idx = self.start_idx + batch_size td = self.td[self.start_idx : end_idx] self.start_idx += batch_size + if self.start_idx >= self.num_samples: + self.start_idx = 0 return td @staticmethod @@ -210,7 +212,4 @@ def list_files(path): if os.path.isfile(os.path.join(path, f)) ] assert len(files) > 0 - files = sorted( - files, key=lambda f: int(os.path.splitext(os.path.basename(f))[0][:4]) - ) return files diff --git a/rl4co/envs/scheduling/fjsp/utils.py b/rl4co/envs/scheduling/fjsp/utils.py index b3ee40b8..f870e8b6 100644 --- a/rl4co/envs/scheduling/fjsp/utils.py +++ b/rl4co/envs/scheduling/fjsp/utils.py @@ -152,8 +152,8 @@ def first_diff(x: Tensor, dim: int): shape = x.shape shape = (*shape[:dim], 1, *shape[dim + 1 :]) seq_cutoff = x.index_select(dim, torch.arange(x.size(dim) - 1, device=x.device)) - lagged_seq = x - torch.cat((seq_cutoff.new_zeros(*shape), seq_cutoff), dim=dim) - return lagged_seq + first_diff_seq = x - torch.cat((seq_cutoff.new_zeros(*shape), seq_cutoff), dim=dim) + return first_diff_seq def spatial_encoding(td: TensorDict): @@ -205,7 +205,7 @@ def calc_lower_bound(td: 
TensorDict): We detect this offset by detecting ops-machine pairs, where the first possible start point of the operation is before the machine becomes idle again - Therefore, we add this discrepancy to the proc_time of the respective ops-ma combination - 2.) If an operation has been scheduled, we use its real finishing time as lower bound. In this case, using the cumulative sum + 2.) If an operation has been scheduled, we use its actual finishing time as lower bound. In this case, using the cumulative sum of all peedecessors of a job does not make sense, since it is likely to differ from the real finishing time of its direct predecessor (its only a lower bound). Therefore, we add the finish time to the cumulative sum of processing time of all UNSCHEDULED operations, to obtain the lower bound. @@ -213,8 +213,6 @@ def calc_lower_bound(td: TensorDict): add them to the matrix of processing times, where already processed operations are masked (with zero) - :param TensorDict td: _description_ - :return _type_: _description_ """ proc_times = td["proc_times"].clone() # (bs, ma, ops) @@ -231,12 +229,15 @@ def calc_lower_bound(td: TensorDict): maybe_start_at = torch.bmm(ops_adj[..., 0], finish_times[..., None]).squeeze(2) # using the start_time, we can determine if and how long an op needs to wait for a machine to finish wait_for_ma_offset = torch.clip(busy_until[..., None] - maybe_start_at[:, None], 0) - # we add this required waiting time to the respective processing time - after that we determine the best machine for each operation - mask = proc_times == 0 - proc_times[mask] = torch.inf - proc_times += wait_for_ma_offset - # select best machine for operation, given the offset - min_proc_times = proc_times.min(1).values + # we add this required waiting time to the respective processing time + proc_time_plus_wait = torch.where( + proc_times == 0, proc_times, proc_times + wait_for_ma_offset + ) + # NOTE get the mean processing time over all eligible machines for lb 
calulation + # ops_proc_times = torch.where(proc_times == 0, torch.inf, proc_time_plus_wait).min(1).values) + ops_proc_times = proc_time_plus_wait.sum(1) / (proc_times.gt(0).sum(1) + 1e-9) + # mask proc times for already scheduled ops + ops_proc_times[op_scheduled.to(torch.bool)] = 0 ############### REGARDING POINT 2 OF DOCSTRING ################### # Now we determine all operations that are not scheduled yet (and thus have no finish_time). We will compute the cumulative @@ -257,7 +258,7 @@ def calc_lower_bound(td: TensorDict): # masking the processing time of scheduled operations and add their finish times instead (first diff thereof) lb_end_expand = ( - proc_matrix_not_scheduled * min_proc_times.unsqueeze(1).expand_as(job_ops_adj) + proc_matrix_not_scheduled * ops_proc_times.unsqueeze(1).expand_as(job_ops_adj) + finish_times_1st_diff ) # (bs, max_ops); lower bound finish time per operation using the cumsum logic diff --git a/rl4co/envs/scheduling/jssp/env.py b/rl4co/envs/scheduling/jssp/env.py new file mode 100644 index 00000000..702ceda7 --- /dev/null +++ b/rl4co/envs/scheduling/jssp/env.py @@ -0,0 +1,123 @@ +import torch + +from einops import einsum, reduce +from tensordict import TensorDict +from torch._tensor import Tensor + +from rl4co.envs import FJSPEnv +from rl4co.utils.ops import gather_by_index + +from .generator import JSSPFileGenerator, JSSPGenerator + + +class JSSPEnv(FJSPEnv): + """Job-Shop Scheduling Problem (JSSP) environment + At each step, the agent chooses a job. The operation to be processed next for the selected job is + then executed on the associated machine. The reward is 0 unless the agent scheduled all operations of all jobs. + In that case, the reward is (-)makespan of the schedule: maximizing the reward is equivalent to minimizing the makespan. + NOTE: The JSSP is a special case of the FJSP, when the number of eligible machines per operation is equal to one for all + operations. 
Therefore, this environment is a subclass of the FJSP environment. + Observations: + - time: current time + - next_op: next operation per job + - proc_times: processing time of operation-machine pairs + - pad_mask: specifies padded operations + - start_op_per_job: id of first operation per job + - end_op_per_job: id of last operation per job + - start_times: start time of operation (defaults to 0 if not scheduled) + - finish_times: finish time of operation (defaults to INIT_FINISH if not scheduled) + - job_ops_adj: adjacency matrix specifying job-operation affiliation + - ops_job_map: same as above but using ids of jobs to indicate affiliation + - ops_sequence_order: specifies the order in which operations have to be processed + - ma_assignment: specifies which operation has been scheduled on which machine + - busy_until: specifies until when the machine will be busy + - num_eligible: number of machines that can process an operation + - job_in_process: whether job is currently being processed + - job_done: whether the job is done + + Constrains: + the agent may not select: + - jobs that are done already + - jobs that are currently processed + + Finish condition: + - the agent has scheduled all operations of all jobs + + Reward: + - the negative makespan of the final schedule + + Args: + generator: JSSPGenerator instance as the data generator + generator_params: parameters for the generator + mask_no_ops: if True, agent may not select waiting operation (unless instance is done) + """ + + name = "jssp" + + def __init__( + self, + generator: JSSPGenerator = None, + generator_params: dict = {}, + mask_no_ops: bool = True, + **kwargs, + ): + if generator is None: + if generator_params.get("file_path", None) is not None: + generator = JSSPFileGenerator(**generator_params) + else: + generator = JSSPGenerator(**generator_params) + + super().__init__(generator, generator_params, mask_no_ops, **kwargs) + + def _get_features(self, td): + td = super()._get_features(td) + # get 
the id of the machine that executes an operation: + # (bs, ops, ma) + ops_ma_adj = td["ops_ma_adj"].transpose(1, 2) + # (bs, jobs, ma) + ma_of_next_op = gather_by_index(ops_ma_adj, td["next_op"], dim=1) + # (bs, jobs) + td["next_ma"] = ma_of_next_op.argmax(-1) + + # adjacency matrix specifying neighbors of an operation, including its + # predecessor and successor operations and operations on the same machine + ops_on_same_ma_adj = einsum( + td["ops_ma_adj"], td["ops_ma_adj"], "b m o1, b m o2 -> b o1 o2 " + ) + # concat pred, succ and ops on same machine + adj = torch.cat((td["ops_adj"], ops_on_same_ma_adj.unsqueeze(-1)), dim=-1).sum(-1) + # mask padded operations and those scheduled + mask = td["pad_mask"] + td["op_scheduled"] + adj.masked_fill_(mask.unsqueeze(1), 0) + td["adjacency"] = adj + + return td + + def get_action_mask(self, td: TensorDict) -> Tensor: + action_mask = self._get_job_machine_availability(td) + if self.mask_no_ops: + # masking is only allowed if instance is finished + no_op_mask = td["done"] + else: + # if no job is currently processed and instance is not finished yet, waiting is not allowed + no_op_mask = ( + td["job_in_process"].any(1, keepdims=True) & (~td["done"]) + ) | td["done"] + # reduce action mask to correspond with logit shape + action_mask = reduce(action_mask, "bs j m -> bs j", reduction="all") + # NOTE: 1 means feasible action, 0 means infeasible action + # (bs, 1 + n_j) + mask = torch.cat((no_op_mask, ~action_mask), dim=1) + return mask + + def _translate_action(self, td): + job = td["action"] + op = gather_by_index(td["next_op"], job, dim=1) + # get the machine that corresponds to the selected operation + ma = gather_by_index(td["ops_ma_adj"], op.unsqueeze(1), dim=2).nonzero()[:, 1] + return job, op, ma + + @staticmethod + def load_data(fpath, batch_size=[]): + g = JSSPFileGenerator(fpath) + return g(batch_size=batch_size) diff --git a/rl4co/envs/scheduling/jssp/generator.py b/rl4co/envs/scheduling/jssp/generator.py new file 
mode 100644 index 00000000..bc9f1fc6 --- /dev/null +++ b/rl4co/envs/scheduling/jssp/generator.py @@ -0,0 +1,208 @@ +import os + +from functools import partial +from typing import List + +import numpy as np +import torch + +from tensordict.tensordict import TensorDict +from torch.nn.functional import one_hot + +from rl4co.envs.common.utils import Generator +from rl4co.utils.pylogger import get_pylogger + +from .parser import get_max_ops_from_files, read + +log = get_pylogger(__name__) + + +class JSSPGenerator(Generator): + + """Data generator for the Job-Shop Scheduling Problem (JSSP) + + Args: + num_stage: number of stages + num_machine: number of machines + num_job: number of jobs + min_time: minimum running time of each job on each machine + max_time: maximum running time of each job on each machine + flatten_stages: whether to flatten the stages + one2one_ma_map: whether each machine should have exactly one operation per job (common in jssp benchmark instances) + + Returns: + A TensorDict with the following key: + start_op_per_job [batch_size, num_jobs]: first operation of each job + end_op_per_job [batch_size, num_jobs]: last operation of each job + proc_times [batch_size, num_machines, total_n_ops]: processing time of ops on machines + pad_mask [batch_size, total_n_ops]: not all instances have the same number of ops, so padding is used + + """ + + def __init__( + self, + num_jobs: int = 6, + num_machines: int = 6, + min_ops_per_job: int = None, + max_ops_per_job: int = None, + min_processing_time: int = 1, + max_processing_time: int = 99, + one2one_ma_map: bool = True, + **unused_kwargs, + ): + self.num_jobs = num_jobs + self.num_mas = num_machines + # quite common in jssp to have as many ops per job as there are machines + self.min_ops_per_job = min_ops_per_job or self.num_mas + self.max_ops_per_job = max_ops_per_job or self.num_mas + self.min_processing_time = min_processing_time + self.max_processing_time = max_processing_time + self.one2one_ma_map = 
one2one_ma_map + if self.one2one_ma_map: + assert self.min_ops_per_job == self.max_ops_per_job == self.num_mas + + # determines whether to use a fixed number of total operations or let it vary between instances + # NOTE: due to the way rl4co builds datasets, we need a fixed size here + self.n_ops_max = self.max_ops_per_job * self.num_jobs + + # FFSP environment doen't have any other kwargs + if len(unused_kwargs) > 0: + log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") + + def _simulate_processing_times(self, bs, n_ops_max) -> torch.Tensor: + if self.one2one_ma_map: + ops_machine_ids = ( + torch.rand((*bs, self.num_jobs, self.num_mas)) + .argsort(dim=-1) + .flatten(1, 2) + ) + else: + ops_machine_ids = torch.randint( + low=0, + high=self.num_mas, + size=(*bs, n_ops_max), + ) + ops_machine_adj = one_hot(ops_machine_ids, num_classes=self.num_mas) + + # (bs, max_ops, machines) + proc_times = torch.ones((*bs, n_ops_max, self.num_mas)) + proc_times = torch.randint( + self.min_processing_time, + self.max_processing_time + 1, + size=(*bs, self.num_mas, n_ops_max), + ) + + # remove proc_times for which there is no corresponding ma-ops connection + proc_times = proc_times * ops_machine_adj.transpose(1, 2) + # in JSSP there is only one machine capable to process an operation + assert (proc_times > 0).sum(1).eq(1).all() + return proc_times.to(torch.float32) + + def _generate(self, batch_size) -> TensorDict: + # simulate how many operations each job has + n_ope_per_job = torch.randint( + self.min_ops_per_job, + self.max_ops_per_job + 1, + size=(*batch_size, self.num_jobs), + ) + + # determine the total number of operations per batch instance (which may differ) + n_ops_batch = n_ope_per_job.sum(1) # (bs) + # determine the maximum total number of operations over all batch instances + n_ops_max = self.n_ops_max or n_ops_batch.max() + + # generate a mask, specifying which operations are padded + pad_mask = 
torch.arange(n_ops_max).unsqueeze(0).expand(*batch_size, -1) + pad_mask = pad_mask.ge(n_ops_batch[:, None].expand_as(pad_mask)) + + # determine the id of the end operation for each job + end_op_per_job = n_ope_per_job.cumsum(1) - 1 + + # determine the id of the starting operation for each job + # (bs, num_jobs) + start_op_per_job = torch.cat( + ( + torch.zeros((*batch_size, 1)).to(end_op_per_job), + end_op_per_job[:, :-1] + 1, + ), + dim=1, + ) + + # simulate processing times for machine-operation pairs + # (bs, num_mas, n_ops_max) + proc_times = self._simulate_processing_times(batch_size, n_ops_max) + + td = TensorDict( + { + "start_op_per_job": start_op_per_job, + "end_op_per_job": end_op_per_job, + "proc_times": proc_times, + "pad_mask": pad_mask, + }, + batch_size=batch_size, + ) + + return td + + +class JSSPFileGenerator(Generator): + """Data generator for the Job-Shop Scheduling Problem (JSSP) using instance files + + Args: + path: path to files + + Returns: + A TensorDict with the following key: + start_op_per_job [batch_size, num_jobs]: first operation of each job + end_op_per_job [batch_size, num_jobs]: last operation of each job + proc_times [batch_size, num_machines, total_n_ops]: processing time of ops on machines + pad_mask [batch_size, total_n_ops]: not all instances have the same number of ops, so padding is used + + """ + + def __init__(self, file_path: str, n_ops_max: int = None, **unused_kwargs): + self.files = ( + [file_path] if os.path.isfile(file_path) else self.list_files(file_path) + ) + self.num_samples = len(self.files) + + if len(unused_kwargs) > 0: + log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") + + if len(self.files) > 1: + n_ops_max = get_max_ops_from_files(self.files) + + ret = map(partial(read, max_ops=n_ops_max), self.files) + + td_list, num_jobs, num_machines, max_ops_per_job = list(zip(*list(ret))) + num_jobs, num_machines = map(lambda x: x[0], (num_jobs, num_machines)) + max_ops_per_job = 
max(max_ops_per_job) + + self.td = torch.cat(td_list, dim=0) + self.num_mas = num_machines + self.num_jobs = num_jobs + self.max_ops_per_job = max_ops_per_job + self.start_idx = 0 + + def _generate(self, batch_size: List[int]) -> TensorDict: + batch_size = np.prod(batch_size) + if batch_size > self.num_samples: + log.warning( + f"Only found {self.num_samples} instance files, but specified dataset size is {batch_size}" + ) + end_idx = self.start_idx + batch_size + td = self.td[self.start_idx : end_idx] + self.start_idx += batch_size + if self.start_idx >= self.num_samples: + self.start_idx = 0 + return td + + @staticmethod + def list_files(path): + files = [ + os.path.join(path, f) + for f in os.listdir(path) + if os.path.isfile(os.path.join(path, f)) + ] + assert len(files) > 0, "No files found in the specified path" + return files diff --git a/rl4co/envs/scheduling/jssp/parser.py b/rl4co/envs/scheduling/jssp/parser.py new file mode 100644 index 00000000..9fcdb4bf --- /dev/null +++ b/rl4co/envs/scheduling/jssp/parser.py @@ -0,0 +1,110 @@ +from pathlib import Path +from typing import List, Tuple, Union + +import torch + +from tensordict import TensorDict + +ProcessingData = List[Tuple[int, int]] + + +def parse_job_line(line: Tuple[int]) -> Tuple[ProcessingData]: + """ + Parses a JSSP job data line of the following form: + + * ( ) + + In words, a line consist of n_ops pairs of values, where the first value is the + machine identifier and the second value is the processing time of the corresponding + operation-machine combination + + Note that the machine indices start from 1, so we subtract 1 to make them + zero-based. 
+ """ + + operations = [] + i = 0 + + while i < len(line): + machine = int(line[i]) + duration = int(line[i + 1]) + operations.append((machine, duration)) + i += 2 + + return operations + + +def get_n_ops_of_instance(file): + lines = file2lines(file) + jobs = [parse_job_line(line) for line in lines[1:]] + n_ope_per_job = torch.Tensor([len(x) for x in jobs]).unsqueeze(0) + total_ops = int(n_ope_per_job.sum()) + return total_ops + + +def get_max_ops_from_files(files): + return max(map(get_n_ops_of_instance, files)) + + +def read(loc: Path, max_ops=None): + """ + Reads an JSSP instance. + + Args: + loc: location of instance file + max_ops: optionally specify the maximum number of total operations (will be filled by padding) + + Returns: + instance: the parsed instance + """ + lines = file2lines(loc) + + # First line contains metadata. + num_jobs, num_machines = lines[0][0], lines[0][1] + + # The remaining lines contain the job-operation data, where each line + # represents a job and its operations. 
+ jobs = [parse_job_line(line) for line in lines[1:]] + n_ope_per_job = torch.Tensor([len(x) for x in jobs]).unsqueeze(0) + total_ops = int(n_ope_per_job.sum()) + if max_ops is not None: + assert total_ops <= max_ops, "got more operations then specified through max_ops" + max_ops = max_ops or total_ops + max_ops_per_job = int(n_ope_per_job.max()) + + end_op_per_job = n_ope_per_job.cumsum(1) - 1 + start_op_per_job = torch.cat((torch.zeros((1, 1)), end_op_per_job[:, :-1] + 1), dim=1) + + pad_mask = torch.arange(max_ops) + pad_mask = pad_mask.ge(total_ops).unsqueeze(0) + + proc_times = torch.zeros((num_machines, max_ops)) + op_cnt = 0 + for job in jobs: + for ma, dur in job: + # subtract one to let indices start from zero + proc_times[ma - 1, op_cnt] = dur + op_cnt += 1 + proc_times = proc_times.unsqueeze(0) + + td = TensorDict( + { + "start_op_per_job": start_op_per_job, + "end_op_per_job": end_op_per_job, + "proc_times": proc_times, + "pad_mask": pad_mask, + }, + batch_size=[1], + ) + + return td, num_jobs, num_machines, max_ops_per_job + + +def file2lines(loc: Union[Path, str]) -> List[List[int]]: + with open(loc, "r") as fh: + lines = [line for line in fh.readlines() if line.strip()] + + def parse_num(word: str): + return int(word) if "." 
not in word else int(float(word)) + + return [[parse_num(x) for x in line.split()] for line in lines] diff --git a/rl4co/models/__init__.py b/rl4co/models/__init__.py index 5b81cd3e..9ee741e4 100644 --- a/rl4co/models/__init__.py +++ b/rl4co/models/__init__.py @@ -14,12 +14,12 @@ NonAutoregressivePolicy, ) from rl4co.models.common.transductive import TransductiveModel +from rl4co.models.rl import StepwisePPO from rl4co.models.rl.a2c.a2c import A2C from rl4co.models.rl.common.base import RL4COLitModule from rl4co.models.rl.ppo.ppo import PPO from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, get_reinforce_baseline from rl4co.models.rl.reinforce.reinforce import REINFORCE -from rl4co.models.zoo import HetGNNModel from rl4co.models.zoo.active_search import ActiveSearch from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy from rl4co.models.zoo.amppo import AMPPO @@ -29,12 +29,19 @@ HeterogeneousAttentionModel, HeterogeneousAttentionModelPolicy, ) +from rl4co.models.zoo.l2d import ( + L2DAttnPolicy, + L2DModel, + L2DPolicy, + L2DPolicy4PPO, + L2DPPOModel, +) from rl4co.models.zoo.matnet import MatNet, MatNetPolicy from rl4co.models.zoo.mdam import MDAM, MDAMPolicy +from rl4co.models.zoo.mvmoe import MVMoE_AM, MVMoE_POMO from rl4co.models.zoo.n2s import N2S, N2SPolicy from rl4co.models.zoo.nargnn import NARGNNPolicy from rl4co.models.zoo.polynet import PolyNet from rl4co.models.zoo.pomo import POMO from rl4co.models.zoo.ptrnet import PointerNetwork, PointerNetworkPolicy from rl4co.models.zoo.symnco import SymNCO, SymNCOPolicy -from rl4co.models.zoo.mvmoe import MVMoE_POMO, MVMoE_AM diff --git a/rl4co/models/nn/attention.py b/rl4co/models/nn/attention.py index 6e4330b0..f7742103 100644 --- a/rl4co/models/nn/attention.py +++ b/rl4co/models/nn/attention.py @@ -2,7 +2,7 @@ import math import warnings -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch import torch.nn as nn @@ -108,21 +108,113 @@ 
def __init__( self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) - def forward(self, x, key_padding_mask=None): + def forward(self, x, attn_mask=None): """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) - key_padding_mask: bool tensor of shape (batch, seqlen) + attn_mask: bool tensor of shape (batch, seqlen) """ # Project query, key, value q, k, v = rearrange( self.Wqkv(x), "b s (three h d) -> three b h s d", three=3, h=self.num_heads ).unbind(dim=0) + if attn_mask is not None: + attn_mask = ( + attn_mask.unsqueeze(1) + if attn_mask.ndim == 3 + else attn_mask.unsqueeze(1).unsqueeze(2) + ) + # Scaled dot product attention out = self.sdpa_fn( q, k, v, - attn_mask=key_padding_mask, + attn_mask=attn_mask, + dropout_p=self.attention_dropout, + ) + return self.out_proj(rearrange(out, "b h s d -> b s (h d)")) + + +def sdpa_fn_wrapper(q, k, v, attn_mask=None, dmat=None, dropout_p=0.0, is_causal=False): + if dmat is not None: + log.warning( + "Edge weights passed to simple attention-fn, which is not supported. Weights will be ignored..." + ) + return scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal + ) + + +class MultiHeadCrossAttention(nn.Module): + """PyTorch native implementation of Flash Multi-Head Cross Attention with automatic mixed precision support. + Uses PyTorch's native `scaled_dot_product_attention` implementation, available from 2.0 + + Note: + If `scaled_dot_product_attention` is not available, use custom implementation of `scaled_dot_product_attention` without Flash Attention. 
+ + Args: + embed_dim: total dimension of the model + num_heads: number of heads + bias: whether to use bias + attention_dropout: dropout rate for attention weights + device: torch device + dtype: torch dtype + sdpa_fn: scaled dot product attention function (SDPA) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = False, + attention_dropout: float = 0.0, + device: str = None, + dtype: torch.dtype = None, + sdpa_fn: Optional[Union[Callable, nn.Module]] = None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.attention_dropout = attention_dropout + + # Default to `scaled_dot_product_attention` if `sdpa_fn` is not provided + if sdpa_fn is None: + sdpa_fn = sdpa_fn_wrapper + self.sdpa_fn = sdpa_fn + + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + assert ( + self.head_dim % 8 == 0 and self.head_dim <= 128 + ), "Only support head_dim <= 128 and divisible by 8" + + self.Wq = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + self.Wkv = nn.Linear(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + + def forward(self, q_input, kv_input, cross_attn_mask=None, dmat=None): + # Project query, key, value + q = rearrange( + self.Wq(q_input), "b m (h d) -> b h m d", h=self.num_heads + ) # [b, h, m, d] + k, v = rearrange( + self.Wkv(kv_input), "b n (two h d) -> two b h n d", two=2, h=self.num_heads + ).unbind( + dim=0 + ) # [b, h, n, d] + + if cross_attn_mask is not None: + # add head dim + cross_attn_mask = cross_attn_mask.unsqueeze(1) + + # Scaled dot product attention + out = self.sdpa_fn( + q, + k, + v, + attn_mask=cross_attn_mask, + dmat=dmat, dropout_p=self.attention_dropout, ) return self.out_proj(rearrange(out, "b h s d -> b s (h d)")) @@ -250,11 +342,15 @@ def 
__init__( sdpa_fn: Optional[Callable] = None, moe_kwargs: Optional[dict] = None, ): - super(PointerAttnMoE, self).__init__(embed_dim, num_heads, mask_inner, out_bias, check_nan, sdpa_fn) + super(PointerAttnMoE, self).__init__( + embed_dim, num_heads, mask_inner, out_bias, check_nan, sdpa_fn + ) self.moe_kwargs = moe_kwargs self.project_out = None - self.project_out_moe = MoE(embed_dim, embed_dim, num_neurons=[], out_bias=out_bias, **moe_kwargs) + self.project_out_moe = MoE( + embed_dim, embed_dim, num_neurons=[], out_bias=out_bias, **moe_kwargs + ) if self.moe_kwargs["light_version"]: self.dense_or_moe = nn.Linear(embed_dim, 2, bias=False) self.project_out = nn.Linear(embed_dim, embed_dim, bias=out_bias) @@ -262,9 +358,16 @@ def __init__( def _project_out(self, out): """Implementation of Hierarchical Gating based on Zhou et al. (2024) .""" if self.moe_kwargs["light_version"]: - probs = F.softmax(self.dense_or_moe(out.view(-1, out.size(-1)).mean(dim=0, keepdim=True)), dim=-1) + probs = F.softmax( + self.dense_or_moe(out.view(-1, out.size(-1)).mean(dim=0, keepdim=True)), + dim=-1, + ) selected = probs.multinomial(1).squeeze(0) - out = self.project_out_moe(out) if selected.item() == 1 else self.project_out(out) + out = ( + self.project_out_moe(out) + if selected.item() == 1 + else self.project_out(out) + ) glimpse = out * probs.squeeze(0)[selected] else: glimpse = self.project_out_moe(out) diff --git a/rl4co/models/nn/env_embeddings/context.py b/rl4co/models/nn/env_embeddings/context.py index 8c651663..1fd63db8 100644 --- a/rl4co/models/nn/env_embeddings/context.py +++ b/rl4co/models/nn/env_embeddings/context.py @@ -321,6 +321,19 @@ def forward(self, embeddings, td): return self.project_context(cur_node_embedding) +class SchedulingContext(nn.Module): + def __init__(self, embed_dim: int, scaling_factor: int = 1000): + super().__init__() + self.scaling_factor = scaling_factor + self.proj_busy = nn.Linear(1, embed_dim, bias=False) + + def forward(self, h, td): + busy_for 
= (td["busy_until"] - td["time"].unsqueeze(1)) / self.scaling_factor + busy_proj = self.proj_busy(busy_for.unsqueeze(-1)) + # (b m e) + return h + busy_proj + + class MTVRPContext(VRPContext): """Context embedding for Multi-Task VRPEnv. Project the following to the embedding space: @@ -338,10 +351,22 @@ def __init__(self, embed_dim): ) def _state_embedding(self, embeddings, td): - remaining_linehaul_capacity = td["vehicle_capacity"] - td["used_capacity_linehaul"] - remaining_backhaul_capacity = td["vehicle_capacity"] - td["used_capacity_backhaul"] + remaining_linehaul_capacity = ( + td["vehicle_capacity"] - td["used_capacity_linehaul"] + ) + remaining_backhaul_capacity = ( + td["vehicle_capacity"] - td["used_capacity_backhaul"] + ) current_time = td["current_time"] current_route_length = td["current_route_length"] open_route = td["open_route"] - return torch.cat([remaining_linehaul_capacity, remaining_backhaul_capacity, current_time, - current_route_length, open_route], -1) + return torch.cat( + [ + remaining_linehaul_capacity, + remaining_backhaul_capacity, + current_time, + current_route_length, + open_route, + ], + -1, + ) diff --git a/rl4co/models/nn/env_embeddings/dynamic.py b/rl4co/models/nn/env_embeddings/dynamic.py index 9d9db535..470af835 100644 --- a/rl4co/models/nn/env_embeddings/dynamic.py +++ b/rl4co/models/nn/env_embeddings/dynamic.py @@ -1,5 +1,7 @@ +import torch import torch.nn as nn +from rl4co.utils.ops import gather_by_index from rl4co.utils.pylogger import get_pylogger log = get_pylogger(__name__) @@ -30,6 +32,8 @@ def env_dynamic_embedding(env_name: str, config: dict) -> nn.Module: "pdp": StaticEmbedding, "mtsp": StaticEmbedding, "smtwtp": StaticEmbedding, + "jssp": JSSPDynamicEmbedding, + "fjsp": JSSPDynamicEmbedding, "mtvrp": StaticEmbedding, } @@ -72,3 +76,46 @@ def forward(self, td): demands_with_depot ).chunk(3, dim=-1) return glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic + + +class JSSPDynamicEmbedding(nn.Module): + def 
__init__(self, embed_dim, linear_bias=False, scaling_factor: int = 1000) -> None: + super().__init__() + self.embed_dim = embed_dim + self.project_node_step = nn.Linear(2, 3 * embed_dim, bias=linear_bias) + self.project_edge_step = nn.Linear(1, 3, bias=linear_bias) + self.scaling_factor = scaling_factor + + def forward(self, td, cache): + ma_emb = cache.node_embeddings["machine_embeddings"] + bs, _, emb_dim = ma_emb.shape + num_jobs = td["next_op"].size(1) + # updates + updates = ma_emb.new_zeros((bs, num_jobs, 3 * emb_dim)) + + lbs = torch.clip(td["lbs"] - td["time"][:, None], 0) / self.scaling_factor + update_feat = torch.stack((lbs, td["is_ready"]), dim=-1) + job_update_feat = gather_by_index(update_feat, td["next_op"], dim=1) + updates = updates + self.project_node_step(job_update_feat) + + ma_busy = td["busy_until"] > td["time"][:, None] + # mask machines currently busy + masked_proc_times = td["proc_times"].clone() / self.scaling_factor + # bs, ma, ops + masked_proc_times[ma_busy] = 0.0 + # bs, ops, ma, 3 + edge_feat = self.project_edge_step(masked_proc_times.unsqueeze(-1)).transpose( + 1, 2 + ) + job_edge_feat = gather_by_index(edge_feat, td["next_op"], dim=1) + # bs, nodes, 3*emb + edge_upd = torch.einsum("ijkl,ikm->ijlm", job_edge_feat, ma_emb).view( + bs, num_jobs, 3 * emb_dim + ) + updates = updates + edge_upd + + # (bs, nodes, emb) + glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = updates.chunk( + 3, dim=-1 + ) + return glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic diff --git a/rl4co/models/nn/env_embeddings/init.py b/rl4co/models/nn/env_embeddings/init.py index 3b094b84..063c1550 100644 --- a/rl4co/models/nn/env_embeddings/init.py +++ b/rl4co/models/nn/env_embeddings/init.py @@ -33,7 +33,8 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module: "mtsp": MTSPInitEmbedding, "smtwtp": SMTWTPInitEmbedding, "mdcpdp": MDCPDPInitEmbedding, - "fjsp": FJSPFeatureEmbedding, + "fjsp": FJSPInitEmbedding, + "jssp": 
FJSPInitEmbedding, "mtvrp": MTVRPInitEmbedding, } @@ -63,7 +64,7 @@ def forward(self, td): class MatNetInitEmbedding(nn.Module): """ - Preparing the initial row and column embeddings for FFSP. + Preparing the initial row and column embeddings for MatNet. Reference: https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L51 @@ -98,7 +99,7 @@ def forward(self, td: TensorDict): col_emb[b_idx, n_idx, rand_idx] = 1.0 elif self.mode == "Random": - col_emb = torch.rand(b, r, self.embed_dim, device=dmat.device) + col_emb = torch.rand(b, c, self.embed_dim, device=dmat.device) else: raise NotImplementedError @@ -386,66 +387,115 @@ def forward(self, td): return torch.cat([depot_embeddings, pick_embeddings, delivery_embeddings], -2) -class FJSPFeatureEmbedding(nn.Module): - def __init__(self, embed_dim, linear_bias=True, norm_coef: int = 100): - super().__init__() +class JSSPInitEmbedding(nn.Module): + def __init__( + self, + embed_dim, + linear_bias: bool = True, + scaling_factor: int = 1000, + num_op_feats=5, + ): + super(JSSPInitEmbedding, self).__init__() self.embed_dim = embed_dim - self.norm_coef = norm_coef + self.scaling_factor = scaling_factor + self.init_ops_embed = nn.Linear(num_op_feats, embed_dim, linear_bias) + self.pos_encoder = PositionalEncoding(embed_dim, dropout=0.0) + + def _op_features(self, td): + proc_times = td["proc_times"] + mean_durations = proc_times.sum(1) / (proc_times.gt(0).sum(1) + 1e-9) + feats = [ + mean_durations / self.scaling_factor, + td["is_ready"], + td["num_eligible"], + td["ops_job_map"], + td["op_scheduled"], + ] + return torch.stack(feats, dim=-1) - self.init_ope_embed = nn.Linear(4, self.embed_dim, bias=False) - self.edge_embed = nn.Linear(1, embed_dim, bias=False) + def _init_ops_embed(self, td: TensorDict): + ops_feat = self._op_features(td) + ops_emb = self.init_ops_embed(ops_feat) + ops_emb = self.pos_encoder(ops_emb, td["ops_sequence_order"]) - self.ope_pos_enc = 
PositionalEncoding(embed_dim) - # TODO allow for reencoding after each step - self.stepwise = False + # zero out padded and finished ops + mask = td["pad_mask"] # NOTE don't mask scheduled - leads to unstable training + ops_emb[mask.unsqueeze(-1).expand_as(ops_emb)] = 0 + return ops_emb - def forward(self, td: TensorDict): - if self.stepwise: - ops_emb = self._stepwise_operations_embed(td) - ma_emb = self._stepwise_machine_embed(td) - edge_emb = None - else: - ops_emb = self._init_operations_embed(td) - ma_emb = self._init_machine_embed(td) - edge_emb = self._init_edge_embed(td) - return ma_emb, ops_emb, edge_emb - - def _init_operations_embed(self, td: TensorDict): - pos = td["ops_sequence_order"] - - features = [ - td["lbs"].unsqueeze(-1) / self.norm_coef, - td["is_ready"].unsqueeze(-1), - td["num_eligible"].unsqueeze(-1), - td["ops_job_map"].unsqueeze(-1), + def forward(self, td): + return self._init_ops_embed(td) + + +class FJSPInitEmbedding(JSSPInitEmbedding): + def __init__(self, embed_dim, linear_bias=False, scaling_factor: int = 100): + super().__init__(embed_dim, linear_bias, scaling_factor, num_op_feats=5) + self.init_ma_embed = nn.Linear(1, self.embed_dim, bias=linear_bias) + self.edge_embed = nn.Linear(1, embed_dim, bias=linear_bias) + + def _op_features(self, td): + feats = [ + td["lbs"] / self.scaling_factor, + td["is_ready"], + td["num_eligible"], + td["op_scheduled"], + td["ops_job_map"], ] - features = torch.cat(features, dim=-1) - # (bs, num_ops, emb_dim) - ops_embeddings = self.init_ope_embed(features) + return torch.stack(feats, dim=-1) - # (bs, num_ops, emb_dim) - ops_embeddings = self.ope_pos_enc(ops_embeddings, pos.to(torch.int64)) - # zero out padded entries - ops_embeddings[td["pad_mask"].unsqueeze(-1).expand_as(ops_embeddings)] = 0 - return ops_embeddings + def forward(self, td: TensorDict): + ops_emb = self._init_ops_embed(td) + ma_emb = self._init_machine_embed(td) + edge_emb = self._init_edge_embed(td) + # get edges between operations 
and machines + # (bs, ops, ma) + edges = td["ops_ma_adj"].transpose(1, 2) + return ops_emb, ma_emb, edge_emb, edges + + def _init_edge_embed(self, td: TensorDict): + proc_times = td["proc_times"].transpose(1, 2) / self.scaling_factor + edge_embed = self.edge_embed(proc_times.unsqueeze(-1)) + return edge_embed def _init_machine_embed(self, td: TensorDict): - bs, num_ma = td["busy_until"].shape - ma_embeddings = torch.zeros( - (bs, num_ma, self.embed_dim), device=td.device, dtype=torch.float32 - ) + busy_for = (td["busy_until"] - td["time"].unsqueeze(1)) / self.scaling_factor + ma_embeddings = self.init_ma_embed(busy_for.unsqueeze(2)) return ma_embeddings - def _init_edge_embed(self, td: TensorDict): - proc_times = td["proc_times"].unsqueeze(-1) / self.norm_coef - edge_embed = self.edge_embed(proc_times) - return edge_embed - def _stepwise_operations_embed(self, td: TensorDict): - raise NotImplementedError("Stepwise encoding not yet implemented") +class FJSPMatNetInitEmbedding(JSSPInitEmbedding): + def __init__( + self, + embed_dim, + linear_bias: bool = False, + scaling_factor: int = 1000, + ): + super().__init__(embed_dim, linear_bias, scaling_factor, num_op_feats=5) + self.init_ma_embed = nn.Linear(1, self.embed_dim, bias=linear_bias) + + def _op_features(self, td): + feats = [ + td["lbs"] / self.scaling_factor, + td["is_ready"], + td["op_scheduled"], + td["num_eligible"], + td["ops_job_map"], + ] + return torch.stack(feats, dim=-1) + + def _init_machine_embed(self, td: TensorDict): + busy_for = (td["busy_until"] - td["time"].unsqueeze(1)) / self.scaling_factor + ma_embeddings = self.init_ma_embed(busy_for.unsqueeze(2)) + return ma_embeddings - def _stepwise_machine_embed(self, td: TensorDict): - raise NotImplementedError("Stepwise encoding not yet implemented") + def forward(self, td: TensorDict): + proc_times = td["proc_times"] + ops_emb = self._init_ops_embed(td) + # encoding machines + ma_emb = self._init_machine_embed(td) + # edgeweights for matnet + 
matnet_edge_weights = proc_times.transpose(1, 2) / self.scaling_factor + return ops_emb, ma_emb, matnet_edge_weights class MTVRPInitEmbedding(VRPInitEmbedding): @@ -455,16 +505,26 @@ def __init__(self, embed_dim, linear_bias=True, node_dim: int = 7): def forward(self, td): depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :] - demand_linehaul, demand_backhaul = td["demand_linehaul"][..., 1:], td["demand_backhaul"][..., 1:] + demand_linehaul, demand_backhaul = ( + td["demand_linehaul"][..., 1:], + td["demand_backhaul"][..., 1:], + ) service_time = td["service_time"][..., 1:] time_windows = td["time_windows"][..., 1:, :] # [!] convert [0, inf] -> [0, 0] if a problem does not include the time window constraint, do not modify in-place - time_windows = torch.nan_to_num(time_windows, posinf=0.0) + time_windows = torch.nan_to_num(time_windows, posinf=0.0) # embeddings depot_embedding = self.init_embed_depot(depot) node_embeddings = self.init_embed( torch.cat( - (cities, demand_linehaul[..., None], demand_backhaul[..., None], time_windows, service_time[..., None]), -1 + ( + cities, + demand_linehaul[..., None], + demand_backhaul[..., None], + time_windows, + service_time[..., None], + ), + -1, ) ) return torch.cat((depot_embedding, node_embeddings), -2) diff --git a/rl4co/models/nn/graph/gcn.py b/rl4co/models/nn/graph/gcn.py index 5f0ba34e..348c6b21 100644 --- a/rl4co/models/nn/graph/gcn.py +++ b/rl4co/models/nn/graph/gcn.py @@ -1,13 +1,15 @@ -from typing import Tuple, Union +from typing import Callable, Tuple, Union import torch.nn as nn import torch.nn.functional as F from tensordict import TensorDict from torch import Tensor -from torch_geometric.data import Batch, Data -from torch_geometric.nn import GCNConv +try: + from torch_geometric.nn import GCNConv +except ImportError: + GCNConv = None from rl4co.models.nn.env_embeddings import env_init_embedding from rl4co.utils.ops import get_full_graph_edge_index from rl4co.utils.pylogger import get_pylogger @@ -15,8 
+17,17 @@ log = get_pylogger(__name__) +EdgeIndexFnSignature = Callable[[TensorDict, int], Tensor] + + +def edge_idx_fn_wrapper(td: TensorDict, num_nodes: int): + # self-loop is added by GCNConv layer + return get_full_graph_edge_index(num_nodes, self_loop=False).to(td.device) + + class GCNEncoder(nn.Module): - """Graph Convolutional Network to encode embeddings with a series of GCN layers + """Graph Convolutional Network to encode embeddings with a series of GCN + layers from the pytorch geometric package Args: embed_dim: dimension of the embeddings @@ -30,15 +41,19 @@ def __init__( self, env_name: str, embed_dim: int, - num_nodes: int, num_layers: int, init_embedding: nn.Module = None, - self_loop: bool = False, residual: bool = True, + edge_idx_fn: EdgeIndexFnSignature = None, + dropout: float = 0.5, + bias: bool = True, ): - super(GCNEncoder, self).__init__() + super().__init__() self.env_name = env_name + self.embed_dim = embed_dim + self.residual = residual + self.dropout = dropout self.init_embedding = ( env_init_embedding(self.env_name, {"embed_dim": embed_dim}) @@ -46,19 +61,17 @@ def __init__( else init_embedding ) - # Generate edge index for a fully connected graph - self.edge_index = get_full_graph_edge_index(num_nodes, self_loop) + if edge_idx_fn is None: + log.warning("No edge indices passed. 
Assume a fully connected graph") + edge_idx_fn = edge_idx_fn_wrapper + + self.edge_idx_fn = edge_idx_fn # Define the GCN layers self.gcn_layers = nn.ModuleList( - [GCNConv(embed_dim, embed_dim) for _ in range(num_layers)] + [GCNConv(embed_dim, embed_dim, bias=bias) for _ in range(num_layers)] ) - # Record parameters - self.residual = residual - self.self_loop = self_loop - - # def forward(self, x, node_feature, mask=None): def forward( self, td: TensorDict, mask: Union[Tensor, None] = None ) -> Tuple[Tensor, Tensor]: @@ -75,35 +88,27 @@ def forward( """ # Transfer to embedding space init_h = self.init_embedding(td) - num_node = init_h.size(-2) - - # Check to update the edge index with different number of node - if num_node != self.edge_index.max().item() + 1: - edge_index = get_full_graph_edge_index(num_node, self.self_loop).to( - init_h.device - ) - else: - edge_index = self.edge_index.to(init_h.device) - - # Create the batched graph - data_list = [Data(x=x, edge_index=edge_index) for x in init_h] - data_batch = Batch.from_data_list(data_list) + bs, num_nodes, emb_dim = init_h.shape + # (bs*num_nodes, emb_dim) + update_node_feature = init_h.reshape(-1, emb_dim) + # shape=(2, num_edges) + edge_index = self.edge_idx_fn(td, num_nodes) - # GCN process - update_node_feature = data_batch.x - edge_index = data_batch.edge_index for layer in self.gcn_layers[:-1]: update_node_feature = layer(update_node_feature, edge_index) update_node_feature = F.relu(update_node_feature) - update_node_feature = F.dropout(update_node_feature, training=self.training) + update_node_feature = F.dropout( + update_node_feature, training=self.training, p=self.dropout + ) + # last layer without relu activation and dropout update_node_feature = self.gcn_layers[-1](update_node_feature, edge_index) # De-batch the graph - input_size = init_h.size() - update_node_feature = update_node_feature.view(*input_size) + update_node_feature = update_node_feature.view(bs, num_nodes, emb_dim) # Residual - 
update_node_feature = update_node_feature + init_h + if self.residual: + update_node_feature = update_node_feature + init_h return update_node_feature, init_h diff --git a/rl4co/models/zoo/hetgnn/encoder.py b/rl4co/models/nn/graph/hgnn.py similarity index 80% rename from rl4co/models/zoo/hetgnn/encoder.py rename to rl4co/models/nn/graph/hgnn.py index 6f966cf8..bd4ce0d2 100644 --- a/rl4co/models/zoo/hetgnn/encoder.py +++ b/rl4co/models/nn/graph/hgnn.py @@ -8,7 +8,7 @@ from torch import Tensor from rl4co.models.nn.env_embeddings import env_init_embedding -from rl4co.models.nn.ops import Normalization +from rl4co.models.nn.ops import TransformerFFN class HetGNNLayer(nn.Module): @@ -75,24 +75,24 @@ def forward( cross_emb = einsum(cross_attn_scores, other_emb_aug, "b m o, b m o e -> b m e") self_emb = self_emb * self_attn_scores # (bs, n_ma, emb_dim) - hidden = torch.sigmoid(cross_emb + self_emb) + hidden = cross_emb + self_emb return hidden class HetGNNBlock(nn.Module): - def __init__(self, embed_dim) -> None: + def __init__(self, embed_dim, normalization: str = "batch") -> None: super().__init__() - self.norm1 = Normalization(embed_dim, normalization="batch") - self.norm2 = Normalization(embed_dim, normalization="batch") self.hgnn1 = HetGNNLayer(embed_dim) self.hgnn2 = HetGNNLayer(embed_dim) + self.ffn1 = TransformerFFN(embed_dim, embed_dim * 2, normalization=normalization) + self.ffn2 = TransformerFFN(embed_dim, embed_dim * 2, normalization=normalization) def forward(self, x1, x2, edge_emb, edges): h1 = self.hgnn1(x1, x2, edge_emb, edges) - h1 = self.norm1(h1 + x1) + h1 = self.ffn1(h1, x1) h2 = self.hgnn2(x2, x1, edge_emb.transpose(1, 2), edges.transpose(1, 2)) - h2 = self.norm2(h2 + x2) + h2 = self.ffn2(h2, x2) return h1, h2 @@ -102,27 +102,28 @@ def __init__( self, embed_dim: int, num_layers: int = 2, + normalization: str = "batch", init_embedding=None, - edge_key: str = "ops_ma_adj", - edge_weights_key: str = "proc_times", - linear_bias: bool = False, + env_name: 
str = "fjsp", + **init_embedding_kwargs, ) -> None: super().__init__() if init_embedding is None: - init_embedding = env_init_embedding("fjsp", {"embed_dim": embed_dim}) - self.init_embedding = init_embedding + init_embedding_kwargs["embed_dim"] = embed_dim + init_embedding = env_init_embedding(env_name, init_embedding_kwargs) - self.edge_key = edge_key - self.edge_weights_key = edge_weights_key + self.init_embedding = init_embedding self.num_layers = num_layers - self.layers = nn.ModuleList([HetGNNBlock(embed_dim) for _ in range(num_layers)]) + self.layers = nn.ModuleList( + [HetGNNBlock(embed_dim, normalization) for _ in range(num_layers)] + ) def forward(self, td): - edges = td[self.edge_key] - bs, n_rows, n_cols = edges.shape - row_emb, col_emb, edge_emb = self.init_embedding(td) + row_emb, col_emb, edge_emb, edges = self.init_embedding(td) + # perform sanity check to validate correct order of row and col embeddings + n_rows, n_cols = edges.shape[1:] assert row_emb.size(1) == n_rows, "incorrect number of row embeddings" assert col_emb.size(1) == n_cols, "incorrect number of column embeddings" diff --git a/rl4co/models/nn/ops.py b/rl4co/models/nn/ops.py index 044ce8cf..f6f774fe 100644 --- a/rl4co/models/nn/ops.py +++ b/rl4co/models/nn/ops.py @@ -1,8 +1,12 @@ +import math + from typing import Tuple, Union import torch import torch.nn as nn +from rl4co.utils.ops import gather_by_index + class SkipConnection(nn.Module): def __init__(self, module): @@ -79,3 +83,55 @@ def forward(self, hidden: torch.Tensor, seq_pos) -> torch.Tensor: ) hidden = hidden + pes return self.dropout(hidden) + + +class TransformerFFN(nn.Module): + def __init__(self, embed_dim, feed_forward_hidden, normalization="batch") -> None: + super().__init__() + + self.ops = nn.ModuleDict( + { + "norm1": Normalization(embed_dim, normalization), + "ffn": nn.Sequential( + nn.Linear(embed_dim, feed_forward_hidden), + nn.ReLU(), + nn.Linear(feed_forward_hidden, embed_dim), + ), + "norm2": 
Normalization(embed_dim, normalization), + } + ) + + def forward(self, x, x_old): + x = self.ops["norm1"](x_old + x) + x = self.ops["norm2"](x + self.ops["ffn"](x)) + + return x + + +class RandomEncoding(nn.Module): + """This is like torch.nn.Embedding, but the rows of the embedding table are randomly + permuted in each forward pass before the lookup operation. This might be useful + in cases where classes have no fixed meaning but rather indicate a connection + between different elements in a sequence. Reference is the MatNet model. + """ + + def __init__(self, embed_dim: int, max_classes: int = 100): + super().__init__() + self.embed_dim = embed_dim + self.max_classes = max_classes + rand_emb = torch.rand(max_classes, self.embed_dim) + self.register_buffer("emb", rand_emb) + + def forward(self, hidden: torch.Tensor, classes=None) -> torch.Tensor: + b, s, _ = hidden.shape + if classes is None: + classes = torch.eye(s).unsqueeze(0).expand(b, s) + assert ( + classes.max() < self.max_classes + ), "number of classes larger than embedding table" + classes = classes.unsqueeze(-1).expand(-1, -1, self.embed_dim) + rand_idx = torch.rand(b, self.max_classes).argsort(dim=1) + embs_permuted = self.emb[rand_idx] + rand_emb = gather_by_index(embs_permuted, classes, dim=1) + hidden = hidden + rand_emb + return hidden diff --git a/rl4co/models/rl/__init__.py b/rl4co/models/rl/__init__.py index c78ff706..1a3bf7e2 100644 --- a/rl4co/models/rl/__init__.py +++ b/rl4co/models/rl/__init__.py @@ -2,4 +2,5 @@ from rl4co.models.rl.common.base import RL4COLitModule from rl4co.models.rl.ppo.n_step_ppo import n_step_PPO from rl4co.models.rl.ppo.ppo import PPO +from rl4co.models.rl.ppo.stepwise_ppo import StepwisePPO from rl4co.models.rl.reinforce.reinforce import REINFORCE diff --git a/rl4co/models/rl/common/utils.py b/rl4co/models/rl/common/utils.py new file mode 100644 index 00000000..b23149f7 --- /dev/null +++ b/rl4co/models/rl/common/utils.py @@ -0,0 +1,46 @@ +import torch + + +class RewardScaler: + 
"""This class calculates the running mean and variance of a stepwise observed + quantity, like the RL reward / advantage using the Welford online algorithm. + The mean and variance are either used to standardize the input (scale='norm') or + to scale it (scale='scale'). + + Args: + scale: None | 'scale' | 'norm': specifies how to transform the input; defaults to None + """ + + def __init__(self, scale: str = None): + self.scale = scale + self.count = 0 + self.mean = 0 + self.M2 = 0 + + def __call__(self, scores: torch.Tensor): + if self.scale is None: + return scores + # Score scaling + self.update(scores) + tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device) + std = (self.M2 / (self.count - 1)).float().sqrt() + score_scaling_factor = std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps + if self.scale == "norm": + scores = (scores - self.mean.to(**tensor_to_kwargs)) / score_scaling_factor + elif self.scale == "scale": + scores /= score_scaling_factor + else: + raise ValueError("unknown scaling operation requested: %s" % self.scale) + return scores + + @torch.no_grad() + def update(self, batch: torch.Tensor): + batch = batch.reshape(-1) + self.count += len(batch) + + # newvalues - oldMean + delta = batch - self.mean + self.mean += (delta / self.count).sum() + # newvalues - newMean + delta2 = batch - self.mean + self.M2 += (delta * delta2).sum() diff --git a/rl4co/models/rl/ppo/stepwise_ppo.py b/rl4co/models/rl/ppo/stepwise_ppo.py new file mode 100644 index 00000000..98186ea1 --- /dev/null +++ b/rl4co/models/rl/ppo/stepwise_ppo.py @@ -0,0 +1,168 @@ +import copy + +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchrl.data.replay_buffers import ( + LazyTensorStorage, + ListStorage, + SamplerWithoutReplacement, + TensorDictReplayBuffer, +) + +from rl4co.envs.common.base import RL4COEnvBase +from rl4co.models.rl.common.base import RL4COLitModule +from rl4co.models.rl.common.utils import 
RewardScaler +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def make_replay_buffer(buffer_size, batch_size, device="cpu"): + if device == "cpu": + storage = LazyTensorStorage(buffer_size, device="cpu") + else: + storage = ListStorage(buffer_size) + return TensorDictReplayBuffer( + storage=storage, + batch_size=batch_size, + sampler=SamplerWithoutReplacement(drop_last=True), + ) + + +class StepwisePPO(RL4COLitModule): + def __init__( + self, + env: RL4COEnvBase, + policy: nn.Module, + clip_range: float = 0.2, # epsilon of PPO + update_timestep: int = 1, + buffer_size: int = 100_000, + ppo_epochs: int = 2, # inner epoch, K + batch_size: int = 256, + mini_batch_size: int = 256, + vf_lambda: float = 0.5, # lambda of Value function fitting + entropy_lambda: float = 0.01, # lambda of entropy bonus + max_grad_norm: float = 0.5, # max gradient norm + buffer_storage_device: str = "gpu", + metrics: dict = { + "train": ["loss", "surrogate_loss", "value_loss", "entropy"], + }, + reward_scale: str = None, + **kwargs, + ): + super().__init__(env, policy, metrics=metrics, batch_size=batch_size, **kwargs) + + self.policy_old = copy.deepcopy(self.policy) + self.automatic_optimization = False # PPO uses custom optimization routine + self.rb = make_replay_buffer(buffer_size, mini_batch_size, buffer_storage_device) + self.scaler = RewardScaler(reward_scale) + + self.ppo_cfg = { + "clip_range": clip_range, + "ppo_epochs": ppo_epochs, + "update_timestep": update_timestep, + "mini_batch_size": mini_batch_size, + "vf_lambda": vf_lambda, + "entropy_lambda": entropy_lambda, + "max_grad_norm": max_grad_norm, + } + + def update(self, device): + outs = [] + # PPO inner epoch + for _ in range(self.ppo_cfg["ppo_epochs"]): + for sub_td in self.rb: + sub_td = sub_td.to(device) + previous_reward = sub_td["reward"].view(-1, 1) + previous_logp = sub_td["logprobs"] + + logprobs, value_pred, entropy = self.policy.evaluate(sub_td) + + ratios = torch.exp(logprobs - 
previous_logp) + + advantages = torch.squeeze(previous_reward - value_pred.detach(), 1) + surr1 = ratios * advantages + surr2 = ( + torch.clamp( + ratios, + 1 - self.ppo_cfg["clip_range"], + 1 + self.ppo_cfg["clip_range"], + ) + * advantages + ) + surrogate_loss = -torch.min(surr1, surr2).mean() + + # compute value function loss + value_loss = F.mse_loss(value_pred, previous_reward) + + # compute total loss + loss = ( + surrogate_loss + + self.ppo_cfg["vf_lambda"] * value_loss + - self.ppo_cfg["entropy_lambda"] * entropy.mean() + ) + + # perform manual optimization following the Lightning routine + # https://lightning.ai/docs/pytorch/stable/common/optimization.html + + opt = self.optimizers() + opt.zero_grad() + self.manual_backward(loss) + if self.ppo_cfg["max_grad_norm"] is not None: + self.clip_gradients( + opt, + gradient_clip_val=self.ppo_cfg["max_grad_norm"], + gradient_clip_algorithm="norm", + ) + opt.step() + + out = { + "reward": previous_reward.mean(), + "loss": loss, + "surrogate_loss": surrogate_loss, + "value_loss": value_loss, + "entropy": entropy.mean(), + } + + outs.append(out) + # Copy new weights into old policy: + self.policy_old.load_state_dict(self.policy.state_dict()) + outs = {k: torch.stack([dic[k] for dic in outs], dim=0) for k in outs[0]} + return outs + + def shared_step( + self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None + ): + next_td = self.env.reset(batch) + device = next_td.device + if phase == "train": + while not next_td["done"].all(): + with torch.no_grad(): + td = self.policy_old.act(next_td, self.env, phase="train") + + assert self.env._torchrl_mode, "Use torchrl mode in stepwise PPO" + td = self.env.step(td) + next_td = td.pop("next") + reward = self.env.get_reward(next_td, None) + reward = self.scaler(reward) + + td.set("reward", reward) + # add tensordict with action, logprobs and reward information to buffer + self.rb.extend(td) + + # if iter mod x = 0 then update the policy (x = 1 in paper) + if 
batch_idx % self.ppo_cfg["update_timestep"] == 0: + out = self.update(device) + self.rb.empty() + + else: + out = self.policy.generate( + next_td, self.env, phase=phase, select_best=phase != "train" + ) + + metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx) + return {"loss": out.get("loss", None), **metrics} diff --git a/rl4co/models/rl/reinforce/reinforce.py b/rl4co/models/rl/reinforce/reinforce.py index 2a153391..269aaaff 100644 --- a/rl4co/models/rl/reinforce/reinforce.py +++ b/rl4co/models/rl/reinforce/reinforce.py @@ -10,6 +10,7 @@ from rl4co.envs.common.base import RL4COEnvBase from rl4co.models.rl.common.base import RL4COLitModule +from rl4co.models.rl.common.utils import RewardScaler from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, get_reinforce_baseline from rl4co.utils.lightning import get_lightning_device from rl4co.utils.pylogger import get_pylogger @@ -35,6 +36,7 @@ def __init__( policy: nn.Module, baseline: Union[REINFORCEBaseline, str] = "rollout", baseline_kwargs: dict = {}, + reward_scale: str = None, **kwargs, ): super().__init__(env, policy, **kwargs) @@ -52,6 +54,7 @@ def __init__( if baseline_kwargs != {}: log.warning("baseline_kwargs is ignored when baseline is not a string") self.baseline = baseline + self.advantage_scaler = RewardScaler(reward_scale) def shared_step( self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None @@ -98,6 +101,7 @@ def calculate_loss( # Main loss function advantage = reward - bl_val # advantage = reward - baseline + advantage = self.advantage_scaler(advantage) reinforce_loss = -(advantage * log_likelihood).mean() loss = reinforce_loss + bl_loss policy_out.update( @@ -197,7 +201,7 @@ def load_from_checkpoint( loaded.setup() loaded.post_setup_hook() # load baseline state dict - state_dict = torch.load(checkpoint_path)["state_dict"] + state_dict = torch.load(checkpoint_path, map_location=map_location)["state_dict"] # get only baseline parameters state_dict = {k: v 
for k, v in state_dict.items() if "baseline" in k} state_dict = {k.replace("baseline.", "", 1): v for k, v in state_dict.items()} diff --git a/rl4co/models/zoo/__init__.py b/rl4co/models/zoo/__init__.py index 8e6dcddd..cb1a21ef 100644 --- a/rl4co/models/zoo/__init__.py +++ b/rl4co/models/zoo/__init__.py @@ -10,13 +10,19 @@ HeterogeneousAttentionModel, HeterogeneousAttentionModelPolicy, ) -from rl4co.models.zoo.hetgnn import HetGNNModel +from rl4co.models.zoo.l2d import ( + L2DAttnPolicy, + L2DModel, + L2DPolicy, + L2DPolicy4PPO, + L2DPPOModel, +) from rl4co.models.zoo.matnet import MatNet, MatNetPolicy from rl4co.models.zoo.mdam import MDAM, MDAMPolicy +from rl4co.models.zoo.mvmoe import MVMoE_AM, MVMoE_POMO from rl4co.models.zoo.n2s import N2S, N2SPolicy from rl4co.models.zoo.nargnn import NARGNNPolicy from rl4co.models.zoo.polynet import PolyNet from rl4co.models.zoo.pomo import POMO from rl4co.models.zoo.ptrnet import PointerNetwork, PointerNetworkPolicy from rl4co.models.zoo.symnco import SymNCO, SymNCOPolicy -from rl4co.models.zoo.mvmoe import MVMoE_POMO, MVMoE_AM diff --git a/rl4co/models/zoo/am/decoder.py b/rl4co/models/zoo/am/decoder.py index 731c076f..61600050 100644 --- a/rl4co/models/zoo/am/decoder.py +++ b/rl4co/models/zoo/am/decoder.py @@ -34,7 +34,7 @@ def fields(self): def batchify(self, num_starts): new_embs = [] for emb in self.fields: - if isinstance(emb, Tensor): + if isinstance(emb, Tensor) or isinstance(emb, TensorDict): new_embs.append(batchify(emb, num_starts)) else: new_embs.append(emb) @@ -108,7 +108,9 @@ def __init__( if pointer is None: # MHA with Pointer mechanism (https://arxiv.org/abs/1506.03134) - pointer_attn_class = PointerAttention if moe_kwargs is None else PointerAttnMoE + pointer_attn_class = ( + PointerAttention if moe_kwargs is None else PointerAttnMoE + ) pointer = pointer_attn_class( embed_dim, num_heads, diff --git a/rl4co/models/zoo/hetgnn/__init__.py b/rl4co/models/zoo/hetgnn/__init__.py deleted file mode 100644 index 
f98562b4..00000000 --- a/rl4co/models/zoo/hetgnn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .model import HetGNNModel diff --git a/rl4co/models/zoo/hetgnn/decoder.py b/rl4co/models/zoo/hetgnn/decoder.py deleted file mode 100644 index 68bf1d36..00000000 --- a/rl4co/models/zoo/hetgnn/decoder.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import torch.nn as nn - -from rl4co.models.common.constructive.autoregressive import AutoregressiveDecoder -from rl4co.models.nn.mlp import MLP -from rl4co.utils.ops import batchify, gather_by_index - - -class HetGNNDecoder(AutoregressiveDecoder): - def __init__( - self, embed_dim, feed_forward_hidden_dim: int = 64, feed_forward_layers: int = 2 - ) -> None: - super().__init__() - self.mlp = MLP( - input_dim=2 * embed_dim, - output_dim=1, - num_neurons=[feed_forward_hidden_dim] * feed_forward_layers, - ) - self.dummy = nn.Parameter(torch.rand(2 * embed_dim)) - - def pre_decoder_hook(self, td, env, hidden, num_starts): - return td, env, hidden - - def forward(self, td, hidden, num_starts): - if num_starts > 1: - hidden = tuple(map(lambda x: batchify(x, num_starts), hidden)) - - ma_emb, ops_emb = hidden - bs, n_rows, emb_dim = ma_emb.shape - - # (bs, n_jobs, emb) - job_emb = gather_by_index(ops_emb, td["next_op"], squeeze=False) - - # (bs, n_jobs, n_ma, emb) - job_emb_expanded = job_emb.unsqueeze(2).expand(-1, -1, n_rows, -1) - ma_emb_expanded = ma_emb.unsqueeze(1).expand_as(job_emb_expanded) - - # Input of actor MLP - # shape: [bs, num_jobs * n_ma, 2*emb] - h_actions = torch.cat((job_emb_expanded, ma_emb_expanded), dim=-1).flatten(1, 2) - no_ops = self.dummy[None, None].expand(bs, 1, -1) # [bs, 1, 2*emb_dim] - # [bs, num_jobs * n_ma + 1, 2*emb_dim] - h_actions_w_noop = torch.cat((no_ops, h_actions), 1) - - # (b, j*m) - mask = td["action_mask"] - - # (b, j*m) - logits = self.mlp(h_actions_w_noop).squeeze(-1) - - return logits, mask diff --git a/rl4co/models/zoo/hetgnn/model.py b/rl4co/models/zoo/hetgnn/model.py deleted file mode 
100644 index 40f27de2..00000000 --- a/rl4co/models/zoo/hetgnn/model.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Union - -from rl4co.envs.common.base import RL4COEnvBase -from rl4co.models.rl import REINFORCE -from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline - -from .policy import HetGNNPolicy - - -class HetGNNModel(REINFORCE): - """Heterogenous Graph Neural Network Model as described by Song et al. (2022): - 'Flexible Job Shop Scheduling via Graph Neural Network and Deep Reinforcement Learning' - - Args: - env: Environment to use for the algorithm - policy: Policy to use for the algorithm - baseline: REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline) - policy_kwargs: Keyword arguments for policy - baseline_kwargs: Keyword arguments for baseline - **kwargs: Keyword arguments passed to the superclass - """ - - def __init__( - self, - env: RL4COEnvBase, - policy: HetGNNPolicy = None, - baseline: Union[REINFORCEBaseline, str] = "rollout", - policy_kwargs={}, - baseline_kwargs={}, - **kwargs, - ): - assert ( - env.name == "fjsp" - ), "HetGNNModel currently only works for FJSP (Flexible Job-Shop Scheduling Problem)" - if policy is None: - policy = HetGNNPolicy(env_name=env.name, **policy_kwargs) - - super().__init__(env, policy, baseline, baseline_kwargs, **kwargs) diff --git a/rl4co/models/zoo/hetgnn/policy.py b/rl4co/models/zoo/hetgnn/policy.py deleted file mode 100644 index c51dc30e..00000000 --- a/rl4co/models/zoo/hetgnn/policy.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Optional - -import torch.nn as nn - -from rl4co.models.common.constructive.autoregressive import ( - AutoregressiveDecoder, - AutoregressiveEncoder, - AutoregressivePolicy, -) -from rl4co.utils.pylogger import get_pylogger - -from .decoder import HetGNNDecoder -from .encoder import HetGNNEncoder - -log = get_pylogger(__name__) - - -class HetGNNPolicy(AutoregressivePolicy): - """ - Base Non-autoregressive policy for NCO 
construction methods. - This creates a heatmap of NxN for N nodes (i.e., heuristic) that models the probability to go from one node to another for all nodes. - - The policy performs the following steps: - 1. Encode the environment initial state into node embeddings - 2. Decode (non-autoregressively) to construct the solution to the NCO problem - - Warning: - The effectiveness of the non-autoregressive approach can vary significantly across different problem types and configurations. - It may require careful tuning of the model architecture and decoding strategy to achieve competitive results. - - Args: - encoder: Encoder module. Can be passed by sub-classes - decoder: Decoder module. Note that this moule defaults to the non-autoregressive decoder - embed_dim: Dimension of the embeddings - env_name: Name of the environment used to initialize embeddings - init_embedding: Model to use for the initial embedding. If None, use the default embedding for the environment - edge_embedding: Model to use for the edge embedding. If None, use the default embedding for the environment - graph_network: Model to use for the graph network. If None, use the default embedding for the environment - heatmap_generator: Model to use for the heatmap generator. 
If None, use the default embedding for the environment - num_layers_heatmap_generator: Number of layers in the heatmap generator - num_layers_graph_encoder: Number of layers in the graph encoder - act_fn: Activation function to use in the encoder - agg_fn: Aggregation function to use in the encoder - linear_bias: Whether to use bias in the encoder - train_decode_type: Type of decoding during training - val_decode_type: Type of decoding during validation - test_decode_type: Type of decoding during testing - **constructive_policy_kw: Unused keyword arguments - """ - - def __init__( - self, - encoder: Optional[AutoregressiveEncoder] = None, - decoder: Optional[AutoregressiveDecoder] = None, - embed_dim: int = 64, - num_encoder_layers: int = 2, - env_name: str = "fjsp", - init_embedding: Optional[nn.Module] = None, - linear_bias: bool = True, - train_decode_type: str = "sampling", - val_decode_type: str = "greedy", - test_decode_type: str = "multistart_sampling", - **constructive_policy_kw, - ): - if len(constructive_policy_kw) > 0: - log.warn(f"Unused kwargs: {constructive_policy_kw}") - - if encoder is None: - encoder = HetGNNEncoder( - embed_dim=embed_dim, - num_layers=num_encoder_layers, - init_embedding=init_embedding, - linear_bias=linear_bias, - ) - - # The decoder generates logits given the current td and heatmap - if decoder is None: - decoder = HetGNNDecoder( - embed_dim=embed_dim, - feed_forward_hidden_dim=embed_dim, - feed_forward_layers=2, - ) - else: - # check if the decoder has trainable parameters - if any(p.requires_grad for p in decoder.parameters()): - log.error( - "The decoder contains trainable parameters. This should not happen in a non-autoregressive policy." 
- ) - - # Pass to constructive policy - super(HetGNNPolicy, self).__init__( - encoder=encoder, - decoder=decoder, - env_name=env_name, - train_decode_type=train_decode_type, - val_decode_type=val_decode_type, - test_decode_type=test_decode_type, - **constructive_policy_kw, - ) diff --git a/rl4co/models/zoo/l2d/__init__.py b/rl4co/models/zoo/l2d/__init__.py new file mode 100644 index 00000000..398dea98 --- /dev/null +++ b/rl4co/models/zoo/l2d/__init__.py @@ -0,0 +1,2 @@ +from .model import L2DModel, L2DPPOModel +from .policy import L2DAttnPolicy, L2DPolicy, L2DPolicy4PPO diff --git a/rl4co/models/zoo/l2d/decoder.py b/rl4co/models/zoo/l2d/decoder.py new file mode 100644 index 00000000..b0ab3041 --- /dev/null +++ b/rl4co/models/zoo/l2d/decoder.py @@ -0,0 +1,390 @@ +import abc + +from typing import Any, Tuple + +import torch +import torch.nn as nn + +from einops import einsum, rearrange +from tensordict import TensorDict +from torch import Tensor + +from rl4co.models.common.constructive.autoregressive import AutoregressiveDecoder +from rl4co.models.nn.attention import PointerAttention +from rl4co.models.nn.env_embeddings.context import SchedulingContext +from rl4co.models.nn.env_embeddings.dynamic import JSSPDynamicEmbedding +from rl4co.models.nn.graph.hgnn import HetGNNEncoder +from rl4co.models.nn.mlp import MLP +from rl4co.models.zoo.am.decoder import AttentionModelDecoder, PrecomputedCache +from rl4co.utils.ops import batchify, gather_by_index + +from .encoder import GCN4JSSP + + +class L2DActor(nn.Module, metaclass=abc.ABCMeta): + """Base decoder model for actor in L2D. The actor is responsible for generating the logits for the action + similar to the decoder in autoregressive models. Since the decoder in L2D can have the additional purpose + of extracting features (i.e. encoding the environment in ever iteration), we need an additional actor class. 
+ This function serves as template for such actor classes in L2D + """ + + @abc.abstractmethod + def forward( + self, td: TensorDict, hidden: Any = None, num_starts: int = 0 + ) -> Tuple[Tensor, Tensor]: + """Obtain logits for current action to the next ones + + Args: + td: TensorDict containing the input data + hidden: Hidden state from the encoder. Can be any type + num_starts: Number of starts for multistart decoding + + Returns: + Tuple containing the logits and the action mask + """ + raise NotImplementedError("Implement me in subclass!") + + def pre_decoder_hook( + self, td: TensorDict, env=None, hidden: Any = None, num_starts: int = 0 + ) -> Tuple[TensorDict, Any]: + """By default, we only require the input for the actor to be a tuple + (in JSSP we only have operation embeddings but in FJSP we have operation + and machine embeddings. By expecting a tuple we can generalize things.) + + Args: + td: TensorDict containing the input data + hidden: Hidden state from the encoder + num_starts: Number of starts for multistart decoding + + Returns: + Tuple containing the updated hidden state(s) and the input TensorDict + """ + + hidden = (hidden,) if not isinstance(hidden, tuple) else hidden + + if num_starts > 1: + # NOTE: when using pomo, we need this + hidden = tuple(map(lambda x: batchify(x, num_starts), hidden)) + + return td, env, hidden + + +class JSSPActor(L2DActor): + def __init__( + self, + embed_dim: int, + hidden_dim: int, + hidden_layers: int = 2, + het_emb: bool = False, + check_nan: bool = True, + ) -> None: + super().__init__() + + input_dim = (1 + int(het_emb)) * embed_dim + self.mlp = MLP( + input_dim=input_dim, + output_dim=1, + num_neurons=[hidden_dim] * hidden_layers, + hidden_act="ReLU", + out_act="Identity", + input_norm="None", + output_norm="None", + ) + self.het_emb = het_emb + self.dummy = nn.Parameter(torch.rand(input_dim)) + self.check_nan = check_nan + + def forward(self, td, op_emb, ma_emb=None): + bs = td.size(0) + # (bs, n_j) + next_op 
= td["next_op"] + # (bs, n_j, emb) + job_emb = gather_by_index(op_emb, next_op, dim=1) + if ma_emb is not None: + ma_emb_per_op = einsum(td["ops_ma_adj"], ma_emb, "b m o, b m e -> b o e") + # (bs, n_j, emb) + ma_emb_per_job = gather_by_index(ma_emb_per_op, next_op, dim=1) + # (bs, n_j, 2 * emb) + job_emb = torch.cat((job_emb, ma_emb_per_job), dim=2) + # (bs, n_j, 2 * emb) + no_ops = self.dummy[None, None].expand(bs, 1, -1) + # (bs, 1 + n_j, 2 * emb) + all_actions = torch.cat((no_ops, job_emb), 1) + # (bs, 1 + n_j) + logits = self.mlp(all_actions).squeeze(2) + + if self.check_nan: + assert not torch.isnan(logits).any(), "Logits contain NaNs" + + # (b, 1 + j) + mask = td["action_mask"] + + return logits, mask + + +class FJSPActor(L2DActor): + def __init__( + self, + embed_dim: int, + hidden_dim: int, + hidden_layers: int = 2, + check_nan: bool = True, + ) -> None: + super().__init__() + self.mlp = MLP( + input_dim=2 * embed_dim, + output_dim=1, + num_neurons=[hidden_dim] * hidden_layers, + hidden_act="ReLU", + out_act="Identity", + input_norm="None", + output_norm="None", + ) + self.dummy = nn.Parameter(torch.rand(2 * embed_dim)) + self.check_nan = check_nan + + def forward(self, td, ops_emb, ma_emb): + bs, n_ma = ma_emb.shape[:2] + # (bs, n_jobs, emb) + job_emb = gather_by_index(ops_emb, td["next_op"], squeeze=False) + # (bs, n_jobs, n_ma, emb) + job_emb_expanded = job_emb.unsqueeze(2).expand(-1, -1, n_ma, -1) + ma_emb_expanded = ma_emb.unsqueeze(1).expand_as(job_emb_expanded) + # (bs, num_jobs * n_ma, 2*emb) + h_actions = torch.cat((job_emb_expanded, ma_emb_expanded), dim=-1).flatten(1, 2) + # (bs, 1, 2*emb_dim) + no_ops = self.dummy[None, None].expand(bs, 1, -1) + # (bs, num_jobs * n_ma + 1, 2*emb_dim) + h_actions_w_noop = torch.cat((no_ops, h_actions), 1) + # (b, j*m) + logits = self.mlp(h_actions_w_noop).squeeze(-1) + + if self.check_nan: + assert not torch.isnan(logits).any(), "Logits contain NaNs" + # (b, 1 + j) + mask = td["action_mask"] + return logits, mask 
+ + +class L2DDecoder(AutoregressiveDecoder): + # feature extractor + actor + def __init__( + self, + env_name: str = "jssp", + feature_extractor: nn.Module = None, + actor: nn.Module = None, + init_embedding: nn.Module = None, + embed_dim: int = 128, + actor_hidden_dim: int = 128, + actor_hidden_layers: int = 2, + num_encoder_layers: int = 3, + num_heads: int = 8, + normalization: str = "batch", + het_emb: bool = False, + stepwise: bool = False, + scaling_factor: int = 1000, + ): + super(L2DDecoder, self).__init__() + + if feature_extractor is None and stepwise: + if env_name == "fjsp" or (het_emb and env_name == "jssp"): + feature_extractor = HetGNNEncoder( + env_name=env_name, + embed_dim=embed_dim, + num_layers=num_encoder_layers, + normalization=normalization, + init_embedding=init_embedding, + scaling_factor=scaling_factor, + ) + else: + feature_extractor = GCN4JSSP( + embed_dim, + num_encoder_layers, + init_embedding=init_embedding, + scaling_factor=scaling_factor, + ) + + self.feature_extractor = feature_extractor + + if actor is None: + if env_name == "fjsp": + actor = FJSPActor( + embed_dim=embed_dim, + hidden_dim=actor_hidden_dim, + hidden_layers=actor_hidden_layers, + ) + else: + actor = JSSPActor( + embed_dim=embed_dim, + hidden_dim=actor_hidden_dim, + hidden_layers=actor_hidden_layers, + het_emb=het_emb, + ) + + self.actor = actor + + def forward(self, td, hidden, num_starts): + if hidden is None: + # NOTE in case we have multiple starts, td is batchified + # (through decoding strategy pre decoding hook). 
Thus the + # embeddings from feature_extractor have the correct shape + num_starts = 0 + # (bs, n_j * n_ops, e), (bs, n_m, e) + hidden, _ = self.feature_extractor(td) + + td, _, hidden = self.actor.pre_decoder_hook(td, None, hidden, num_starts) + + # (bs, n_j, e) + logits, mask = self.actor(td, *hidden) + + return logits, mask + + +class L2DAttnPointer(PointerAttention): + def __init__( + self, + env_name: str, + embed_dim: int, + num_heads: int, + out_bias: bool = False, + check_nan: bool = True, + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + mask_inner=False, + out_bias=out_bias, + check_nan=check_nan, + ) + self.env_name = env_name + + def forward(self, query, key, value, logit_key, attn_mask=None): + # bs = query.size(0) + # (b m j) + logits = super().forward(query, key, value, logit_key, attn_mask=attn_mask) + if self.env_name == "jssp": + # (b j) + logits = logits.sum(1) + elif self.env_name == "fjsp": + no_op_logits = logits[..., 0].sum(1, keepdims=True) + logits = rearrange(logits[..., 1:], "b m j -> b (j m)") + logits = torch.cat((no_op_logits, logits), dim=1) + + return logits + + +class AttnActor(AttentionModelDecoder): + def __init__( + self, + embed_dim: int = 128, + num_heads: int = 8, + env_name: str = "tsp", + context_embedding: nn.Module = None, + dynamic_embedding: nn.Module = None, + mask_inner: bool = True, + out_bias_pointer_attn: bool = False, + linear_bias: bool = False, + use_graph_context: bool = True, + check_nan: bool = True, + sdpa_fn: callable = None, + pointer: nn.Module = None, + moe_kwargs: dict = None, + ): + super().__init__( + embed_dim, + num_heads, + env_name, + context_embedding, + dynamic_embedding, + mask_inner, + out_bias_pointer_attn, + linear_bias, + use_graph_context, + check_nan, + sdpa_fn, + pointer, + moe_kwargs, + ) + + def pre_decoder_hook( + self, td: TensorDict, env=None, hidden: Any = None, num_starts: int = 0 + ) -> Tuple[TensorDict, Any]: + cache = self._precompute_cache(hidden, 
num_starts=num_starts) + return td, env, (cache,) + + +class L2DAttnActor(AttnActor): + def __init__( + self, + embed_dim: int = 128, + num_heads: int = 8, + env_name: str = "jssp", + scaling_factor: int = 1000, + stepwise: bool = False, + ): + context_embedding = SchedulingContext(embed_dim, scaling_factor=scaling_factor) + if stepwise: + # in a stepwise encoding setting, the embeddings contain all current information + dynamic_embedding = None + else: + # otherwise we might want to update the static embeddings using dynamic updates + dynamic_embedding = JSSPDynamicEmbedding( + embed_dim, scaling_factor=scaling_factor + ) + pointer = L2DAttnPointer(env_name, embed_dim, num_heads, check_nan=False) + + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + env_name=env_name, + context_embedding=context_embedding, + dynamic_embedding=dynamic_embedding, + pointer=pointer, + ) + self.dummy = nn.Parameter(torch.rand(1, embed_dim)) + + def _compute_q(self, cached: PrecomputedCache, td: TensorDict): + embeddings = cached.node_embeddings + ma_embs = embeddings["machine_embeddings"] + return self.context_embedding(ma_embs, td) + + def _compute_kvl(self, cached: PrecomputedCache, td: TensorDict): + glimpse_k_stat, glimpse_v_stat, logit_k_stat = ( + gather_by_index(cached.glimpse_key, td["next_op"], dim=1), + gather_by_index(cached.glimpse_val, td["next_op"], dim=1), + gather_by_index(cached.logit_key, td["next_op"], dim=1), + ) + # Compute dynamic embeddings and add to static embeddings + glimpse_k_dyn, glimpse_v_dyn, logit_k_dyn = self.dynamic_embedding(td, cached) + glimpse_k = glimpse_k_stat + glimpse_k_dyn + glimpse_v = glimpse_v_stat + glimpse_v_dyn + logit_k = logit_k_stat + logit_k_dyn + + no_ops = self.dummy.unsqueeze(1).expand(td.size(0), 1, -1).to(logit_k) + logit_k = torch.cat((no_ops, logit_k), dim=1) + + return glimpse_k, glimpse_v, logit_k + + def _precompute_cache(self, embeddings: Tuple[torch.Tensor, torch.Tensor], **kwargs): + ops_emb, ma_emb = 
embeddings + + ( + glimpse_key_fixed, + glimpse_val_fixed, + logit_key, + ) = self.project_node_embeddings( + ops_emb + ).chunk(3, dim=-1) + + embeddings = TensorDict( + {"op_embeddings": ops_emb, "machine_embeddings": ma_emb}, + batch_size=ops_emb.size(0), + ) + # Organize in a dataclass for easy access + return PrecomputedCache( + node_embeddings=embeddings, + graph_context=0, + glimpse_key=glimpse_key_fixed, + glimpse_val=glimpse_val_fixed, + logit_key=logit_key, + ) diff --git a/rl4co/models/zoo/l2d/encoder.py b/rl4co/models/zoo/l2d/encoder.py new file mode 100644 index 00000000..0bc43fa9 --- /dev/null +++ b/rl4co/models/zoo/l2d/encoder.py @@ -0,0 +1,26 @@ +from rl4co.models.nn.env_embeddings.init import JSSPInitEmbedding +from rl4co.models.nn.graph.gcn import GCNEncoder +from rl4co.utils.ops import adj_to_pyg_edge_index + + +class GCN4JSSP(GCNEncoder): + def __init__( + self, + embed_dim: int, + num_layers: int, + init_embedding=None, + **init_embedding_kwargs, + ): + def edge_idx_fn(td, _): + return adj_to_pyg_edge_index(td["adjacency"]) + + if init_embedding is None: + init_embedding = JSSPInitEmbedding(embed_dim, **init_embedding_kwargs) + + super().__init__( + env_name="jssp", + embed_dim=embed_dim, + num_layers=num_layers, + edge_idx_fn=edge_idx_fn, + init_embedding=init_embedding, + ) diff --git a/rl4co/models/zoo/l2d/model.py b/rl4co/models/zoo/l2d/model.py new file mode 100644 index 00000000..b70784b1 --- /dev/null +++ b/rl4co/models/zoo/l2d/model.py @@ -0,0 +1,69 @@ +from typing import Union + +from rl4co.envs.common.base import RL4COEnvBase +from rl4co.models.rl import REINFORCE, StepwisePPO +from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline + +from .policy import L2DPolicy, L2DPolicy4PPO + + +class L2DPPOModel(StepwisePPO): + """Learning2Dispatch model by Zhang et al. 
(2020): + 'Learning to Dispatch for Job Shop Scheduling via Deep Reinforcement Learning' + + Args: + env: Environment to use for the algorithm + policy: Policy to use for the algorithm + baseline: REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline) + policy_kwargs: Keyword arguments for policy + baseline_kwargs: Keyword arguments for baseline + **kwargs: Keyword arguments passed to the superclass + """ + + def __init__( + self, + env: RL4COEnvBase, + policy: L2DPolicy = None, + policy_kwargs={}, + **kwargs, + ): + assert env.name in [ + "fjsp", + "jssp", + ], "L2DModel currently only works for Job-Shop Scheduling Problems" + if policy is None: + policy = L2DPolicy4PPO(env_name=env.name, **policy_kwargs) + + super().__init__(env, policy, **kwargs) + + +class L2DModel(REINFORCE): + """Learning2Dispatch model by Zhang et al. (2020): + 'Learning to Dispatch for Job Shop Scheduling via Deep Reinforcement Learning' + + Args: + env: Environment to use for the algorithm + policy: Policy to use for the algorithm + baseline: REINFORCE baseline. 
Defaults to rollout (1 epoch of exponential, then greedy rollout baseline) + policy_kwargs: Keyword arguments for policy + baseline_kwargs: Keyword arguments for baseline + **kwargs: Keyword arguments passed to the superclass + """ + + def __init__( + self, + env: RL4COEnvBase, + policy: L2DPolicy = None, + baseline: Union[REINFORCEBaseline, str] = "rollout", + policy_kwargs={}, + baseline_kwargs={}, + **kwargs, + ): + assert env.name in [ + "fjsp", + "jssp", + ], "L2DModel currently only works for Job-Shop Scheduling Problems" + if policy is None: + policy = L2DPolicy(env_name=env.name, **policy_kwargs) + + super().__init__(env, policy, baseline, baseline_kwargs, **kwargs) diff --git a/rl4co/models/zoo/l2d/policy.py b/rl4co/models/zoo/l2d/policy.py new file mode 100644 index 00000000..b4b9b11c --- /dev/null +++ b/rl4co/models/zoo/l2d/policy.py @@ -0,0 +1,248 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from torch.distributions import Categorical + +from rl4co.models.common.constructive.autoregressive import ( + AutoregressiveDecoder, + AutoregressiveEncoder, + AutoregressivePolicy, +) +from rl4co.models.common.constructive.base import NoEncoder +from rl4co.models.nn.env_embeddings.init import FJSPMatNetInitEmbedding +from rl4co.models.nn.graph.hgnn import HetGNNEncoder +from rl4co.models.nn.mlp import MLP +from rl4co.models.zoo.matnet.matnet_w_sa import Encoder +from rl4co.utils.decoding import DecodingStrategy, process_logits +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +from .decoder import L2DAttnActor, L2DDecoder +from .encoder import GCN4JSSP + +log = get_pylogger(__name__) + + +class L2DPolicy(AutoregressivePolicy): + def __init__( + self, + encoder: Optional[AutoregressiveEncoder] = None, + decoder: Optional[AutoregressiveDecoder] = None, + embed_dim: int = 64, + num_encoder_layers: int = 2, + env_name: str = "fjsp", + het_emb: bool = True, + scaling_factor: int = 1000, + 
init_embedding: Optional[nn.Module] = None, + stepwise_encoding: bool = False, + tanh_clipping: float = 10, + train_decode_type: str = "sampling", + val_decode_type: str = "greedy", + test_decode_type: str = "multistart_sampling", + **constructive_policy_kw, + ): + if len(constructive_policy_kw) > 0: + log.warn(f"Unused kwargs: {constructive_policy_kw}") + + if encoder is None: + if stepwise_encoding: + encoder = NoEncoder() + elif env_name == "fjsp" or (env_name == "jssp" and het_emb): + encoder = HetGNNEncoder( + env_name=env_name, + embed_dim=embed_dim, + num_layers=num_encoder_layers, + normalization="batch", + init_embedding=init_embedding, + scaling_factor=scaling_factor, + ) + else: + encoder = GCN4JSSP( + embed_dim, + num_encoder_layers, + init_embedding=init_embedding, + scaling_factor=scaling_factor, + ) + + # The decoder generates logits given the current td and heatmap + if decoder is None: + decoder = L2DDecoder( + env_name=env_name, + embed_dim=embed_dim, + actor_hidden_dim=embed_dim, + num_encoder_layers=num_encoder_layers, + init_embedding=init_embedding, + het_emb=het_emb, + stepwise=stepwise_encoding, + scaling_factor=scaling_factor, + ) + + # Pass to constructive policy + super(L2DPolicy, self).__init__( + encoder=encoder, + decoder=decoder, + env_name=env_name, + tanh_clipping=tanh_clipping, + train_decode_type=train_decode_type, + val_decode_type=val_decode_type, + test_decode_type=test_decode_type, + **constructive_policy_kw, + ) + + +class L2DAttnPolicy(AutoregressivePolicy): + def __init__( + self, + encoder: Optional[AutoregressiveEncoder] = None, + decoder: Optional[AutoregressiveDecoder] = None, + embed_dim: int = 256, + num_heads: int = 8, + num_encoder_layers: int = 4, + scaling_factor: int = 1000, + env_name: str = "fjsp", + init_embedding: Optional[nn.Module] = None, + tanh_clipping: float = 10, + train_decode_type: str = "sampling", + val_decode_type: str = "greedy", + test_decode_type: str = "multistart_sampling", + 
**constructive_policy_kw, + ): + if len(constructive_policy_kw) > 0: + log.warn(f"Unused kwargs: {constructive_policy_kw}") + + if encoder is None: + if init_embedding is None: + init_embedding = FJSPMatNetInitEmbedding( + embed_dim, scaling_factor=scaling_factor + ) + + encoder = Encoder( + embed_dim=embed_dim, + num_heads=num_heads, + num_layers=num_encoder_layers, + normalization="batch", + feedforward_hidden=embed_dim * 2, + init_embedding=init_embedding, + ) + + # The decoder generates logits given the current td and heatmap + if decoder is None: + decoder = L2DAttnActor( + env_name=env_name, + embed_dim=embed_dim, + num_heads=num_heads, + scaling_factor=scaling_factor, + stepwise=False, + ) + + # Pass to constructive policy + super(L2DAttnPolicy, self).__init__( + encoder=encoder, + decoder=decoder, + env_name=env_name, + tanh_clipping=tanh_clipping, + train_decode_type=train_decode_type, + val_decode_type=val_decode_type, + test_decode_type=test_decode_type, + **constructive_policy_kw, + ) + + +class L2DPolicy4PPO(L2DPolicy): + def __init__( + self, + encoder=None, + decoder=None, + critic=None, + embed_dim: int = 64, + num_encoder_layers: int = 2, + env_name: str = "fjsp", + het_emb: bool = True, + scaling_factor: int = 1000, + init_embedding=None, + tanh_clipping: float = 10, + train_decode_type: str = "sampling", + val_decode_type: str = "greedy", + test_decode_type: str = "multistart_sampling", + **constructive_policy_kw, + ): + if init_embedding is None: + pass # TODO PPO specific init emb? 
+ + super().__init__( + encoder=encoder, + decoder=decoder, + embed_dim=embed_dim, + num_encoder_layers=num_encoder_layers, + env_name=env_name, + het_emb=het_emb, + scaling_factor=scaling_factor, + init_embedding=init_embedding, + stepwise_encoding=True, + tanh_clipping=tanh_clipping, + train_decode_type=train_decode_type, + val_decode_type=val_decode_type, + test_decode_type=test_decode_type, + **constructive_policy_kw, + ) + + if critic is None: + if env_name == "fjsp" or het_emb: + input_dim = 2 * embed_dim + else: + input_dim = embed_dim + critic = MLP(input_dim, 1, num_neurons=[embed_dim] * 2) + + self.critic = critic + assert isinstance( + self.encoder, NoEncoder + ), "Define a feature extractor for decoder rather than an encoder in stepwise PPO" + + def evaluate(self, td): + # Encoder: get encoder output and initial embeddings from initial state + hidden, _ = self.decoder.feature_extractor(td) + # pool the embeddings for the critic + h_tuple = (hidden,) if isinstance(hidden, torch.Tensor) else hidden + pooled = tuple(map(lambda x: x.mean(dim=-2), h_tuple)) + # potentially cat multiple embeddings (pooled ops and machines) + h_pooled = torch.cat(pooled, dim=-1) + # pred value via the value head + value_pred = self.critic(h_pooled) + # pre decoder / actor hook + td, _, hidden = self.decoder.actor.pre_decoder_hook( + td, None, hidden, num_starts=0 + ) + logits, mask = self.decoder.actor(td, *hidden) + # get logprobs and entropy over logp distribution + logprobs = process_logits(logits, mask, tanh_clipping=self.tanh_clipping) + action_logprobs = gather_by_index(logprobs, td["action"], dim=1) + dist_entropys = Categorical(logprobs.exp()).entropy() + + return action_logprobs, value_pred, dist_entropys + + def act(self, td, env, phase: str = "train"): + logits, mask = self.decoder(td, hidden=None, num_starts=0) + logprobs = process_logits(logits, mask, tanh_clipping=self.tanh_clipping) + + # DRL-S, sampling actions following \pi + if phase == "train": + 
action_indexes = DecodingStrategy.sampling(logprobs) + td["logprobs"] = gather_by_index(logprobs, action_indexes, dim=1) + + # DRL-G, greedily picking actions with the maximum probability + else: + action_indexes = DecodingStrategy.greedy(logprobs) + + # memories.states.append(copy.deepcopy(state)) + td["action"] = action_indexes + + return td + + @torch.no_grad() + def generate(self, td, env=None, phase: str = "train", **kwargs) -> dict: + assert phase != "train", "dont use generate() in training mode" + with torch.no_grad(): + out = super().__call__(td, env, phase=phase, **kwargs) + return out diff --git a/rl4co/models/zoo/matnet/decoder.py b/rl4co/models/zoo/matnet/decoder.py index 616d38f9..ad8aef77 100644 --- a/rl4co/models/zoo/matnet/decoder.py +++ b/rl4co/models/zoo/matnet/decoder.py @@ -127,8 +127,8 @@ def __init__( self.cached_embs: PrecomputedCache = None # self.encoded_wait_op = nn.Parameter(torch.rand((1, 1, embed_dim))) - def _precompute_cache(self, embeddings: Tuple[Tensor], td: TensorDict = None): - self.cached_embs = super()._precompute_cache(embeddings, td) + def _precompute_cache(self, embeddings: Tuple[Tensor], **kwargs): + self.cached_embs = super()._precompute_cache(embeddings, **kwargs) def forward( self, @@ -141,7 +141,7 @@ def forward( batch_size = td.size(0) # TODO: we need to insert precompute cache inside the decoder - logits, mask = super().forward(self.cached_embs, td, num_starts) + logits, mask = super().forward(td, self.cached_embs, num_starts) logprobs = process_logits( logits, mask, diff --git a/rl4co/models/zoo/matnet/encoder.py b/rl4co/models/zoo/matnet/encoder.py index 87f8ffec..3ad52309 100644 --- a/rl4co/models/zoo/matnet/encoder.py +++ b/rl4co/models/zoo/matnet/encoder.py @@ -1,41 +1,28 @@ -import math - from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange - +from rl4co.models.nn.attention import MultiHeadCrossAttention from rl4co.models.nn.env_embeddings 
import env_init_embedding from rl4co.models.nn.ops import Normalization -class MatNetCrossMHA(nn.Module): +class MixedScoresSDPA(nn.Module): def __init__( self, - embed_dim: int, num_heads: int, - bias: bool = False, + num_scores: int = 1, mixer_hidden_dim: int = 16, mix1_init: float = (1 / 2) ** (1 / 2), mix2_init: float = (1 / 16) ** (1 / 2), ): super().__init__() - self.embed_dim = embed_dim self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" - self.head_dim = self.embed_dim // num_heads - - self.Wq = nn.Linear(embed_dim, embed_dim, bias=bias) - self.Wkv = nn.Linear(embed_dim, 2 * embed_dim, bias=bias) - - # Score mixer - # Taken from the official MatNet implementation - # https://github.com/yd-kwon/MatNet/blob/main/ATSP/ATSP_MatNet/ATSPModel_LIB.py#L72 + self.num_scores = num_scores mix_W1 = torch.torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample( - (num_heads, 2, mixer_hidden_dim) + (num_heads, self.num_scores + 1, mixer_hidden_dim) ) mix_b1 = torch.torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample( (num_heads, mixer_hidden_dim) @@ -52,38 +39,23 @@ def __init__( self.mix_W2 = nn.Parameter(mix_W2) self.mix_b2 = nn.Parameter(mix_b2) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - def forward(self, q_input, kv_input, dmat): - """ - - Args: - q_input (Tensor): [b, m, d] - kv_input (Tensor): [b, n, d] - dmat (Tensor): [b, m, n] - - Returns: - Tensor: [b, m, d] - """ - - b, m, n = dmat.shape + def forward(self, q, k, v, attn_mask=None, dmat=None, dropout_p=0.0): + """Scaled Dot-Product Attention with MatNet Scores Mixer""" + assert dmat is not None + b, m, n = dmat.shape[:3] + dmat = dmat.reshape(b, m, n, self.num_scores) - q = rearrange( - self.Wq(q_input), "b m (h d) -> b h m d", h=self.num_heads - ) # [b, h, m, d] - k, v = rearrange( - self.Wkv(kv_input), "b n (two h d) -> two b h n d", two=2, h=self.num_heads - ).unbind( - dim=0 - ) # [b, h, n, d] - - 
scale = math.sqrt(q.size(-1)) # scale factor - attn_scores = torch.matmul(q, k.transpose(2, 3)) / scale # [b, h, m, n] - mix_attn_scores = torch.stack( - [attn_scores, dmat[:, None, :, :].expand(b, self.num_heads, m, n)], dim=-1 - ) # [b, h, m, n, 2] + # Calculate scaled dot product + attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5) + mix_attn_scores = torch.cat( + [ + attn_scores.unsqueeze(-1), + dmat[:, None, ...].expand(b, self.num_heads, m, n, self.num_scores), + ], + dim=-1, + ) # [b, h, m, n, num_scores+1] - mix_attn_scores = ( + attn_scores = ( ( torch.matmul( F.relu( @@ -98,9 +70,45 @@ def forward(self, q_input, kv_input, dmat): .squeeze(-1) ) # [b, h, m, n] - attn_probs = F.softmax(mix_attn_scores, dim=-1) - out = torch.matmul(attn_probs, v) - return self.out_proj(rearrange(out, "b h s d -> b s (h d)")) + # Apply the provided attention mask + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask[~attn_mask.any(-1)] = True + attn_scores.masked_fill_(~attn_mask, float("-inf")) + else: + attn_scores += attn_mask + + # Softmax to get attention weights + attn_weights = F.softmax(attn_scores, dim=-1) + + # Apply dropout + if dropout_p > 0.0: + attn_weights = F.dropout(attn_weights, p=dropout_p) + + # Compute the weighted sum of values + return torch.matmul(attn_weights, v) + + +class MatNetCrossMHA(MultiHeadCrossAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = False, + mixer_hidden_dim: int = 16, + mix1_init: float = (1 / 2) ** (1 / 2), + mix2_init: float = (1 / 16) ** (1 / 2), + ): + attn_fn = MixedScoresSDPA( + num_heads=num_heads, + mixer_hidden_dim=mixer_hidden_dim, + mix1_init=mix1_init, + mix2_init=mix2_init, + ) + + super().__init__( + embed_dim=embed_dim, num_heads=num_heads, bias=bias, sdpa_fn=attn_fn + ) class MatNetMHA(nn.Module): @@ -109,7 +117,7 @@ def __init__(self, embed_dim: int, num_heads: int, bias: bool = False): self.row_encoding_block = 
MatNetCrossMHA(embed_dim, num_heads, bias) self.col_encoding_block = MatNetCrossMHA(embed_dim, num_heads, bias) - def forward(self, row_emb, col_emb, dmat): + def forward(self, row_emb, col_emb, dmat, attn_mask=None): """ Args: row_emb (Tensor): [b, m, d] @@ -120,10 +128,15 @@ def forward(self, row_emb, col_emb, dmat): Updated row_emb (Tensor): [b, m, d] Updated col_emb (Tensor): [b, n, d] """ - - updated_row_emb = self.row_encoding_block(row_emb, col_emb, dmat) + updated_row_emb = self.row_encoding_block( + row_emb, col_emb, dmat=dmat, cross_attn_mask=attn_mask + ) + attn_mask_t = attn_mask.transpose(-2, -1) if attn_mask is not None else None updated_col_emb = self.col_encoding_block( - col_emb, row_emb, dmat.transpose(-2, -1) + col_emb, + row_emb, + dmat=dmat.transpose(-2, -1), + cross_attn_mask=attn_mask_t, ) return updated_row_emb, updated_col_emb @@ -164,7 +177,7 @@ def __init__( } ) - def forward(self, row_emb, col_emb, dmat): + def forward(self, row_emb, col_emb, dmat, attn_mask=None): """ Args: row_emb (Tensor): [b, m, d] @@ -176,7 +189,7 @@ def forward(self, row_emb, col_emb, dmat): Updated col_emb (Tensor): [b, n, d] """ - row_emb_out, col_emb_out = self.MHA(row_emb, col_emb, dmat) + row_emb_out, col_emb_out = self.MHA(row_emb, col_emb, dmat, attn_mask) row_emb_out = self.F_a["norm1"](row_emb + row_emb_out) row_emb_out = self.F_a["norm2"](row_emb_out + self.F_a["ffn"](row_emb_out)) @@ -210,7 +223,7 @@ def __init__( ] ) - def forward(self, row_emb, col_emb, dmat): + def forward(self, row_emb, col_emb, dmat, attn_mask=None): """ Args: row_emb (Tensor): [b, m, d] @@ -223,7 +236,7 @@ def forward(self, row_emb, col_emb, dmat): """ for layer in self.layers: - row_emb, col_emb = layer(row_emb, col_emb, dmat) + row_emb, col_emb = layer(row_emb, col_emb, dmat, attn_mask) return row_emb, col_emb @@ -236,8 +249,9 @@ def __init__( normalization: str = "instance", feedforward_hidden: int = 512, init_embedding: nn.Module = None, - init_embedding_kwargs: dict = None, + 
init_embedding_kwargs: dict = {}, bias: bool = False, + mask_non_neighbors: bool = False, ): super().__init__() @@ -255,10 +269,16 @@ def __init__( feedforward_hidden=feedforward_hidden, bias=bias, ) + self.mask_non_neighbors = mask_non_neighbors - def forward(self, td): + def forward(self, td, attn_mask: torch.Tensor = None): row_emb, col_emb, dmat = self.init_embedding(td) - row_emb, col_emb = self.net(row_emb, col_emb, dmat) + + if self.mask_non_neighbors and attn_mask is None: + # attn_mask (keep 1s discard 0s) to only attend on neighborhood + attn_mask = dmat.ne(0) + + row_emb, col_emb = self.net(row_emb, col_emb, dmat, attn_mask) embedding = (row_emb, col_emb) init_embedding = None diff --git a/rl4co/models/zoo/matnet/matnet_w_sa.py b/rl4co/models/zoo/matnet/matnet_w_sa.py new file mode 100644 index 00000000..cf06056f --- /dev/null +++ b/rl4co/models/zoo/matnet/matnet_w_sa.py @@ -0,0 +1,202 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +from rl4co.models.nn.attention import MultiHeadAttention +from rl4co.models.nn.env_embeddings import env_init_embedding +from rl4co.models.nn.ops import Normalization, TransformerFFN + + +def apply_weights_and_combine(dots, v, tanh_clipping=0): + # scale to avoid numerical underflow + logits = dots / dots.std() + if tanh_clipping > 0: + # tanh clipping to avoid explosions + logits = torch.tanh(logits) * tanh_clipping + # shape: (batch, num_heads, row_cnt, col_cnt) + weights = nn.Softmax(dim=-1)(logits) + weights = weights.nan_to_num(0) + # shape: (batch, num_heads, row_cnt, qkv_dim) + out = torch.matmul(weights, v) + # shape: (batch, row_cnt, num_heads, qkv_dim) + out = rearrange(out, "b h s d -> b s (h d)") + return out + + +class MixedScoreFF(nn.Module): + def __init__(self, num_heads, ms_hidden_dim: int = 32, bias: bool = False) -> None: + super().__init__() + + self.lin1 = nn.Linear(2 * num_heads, num_heads * ms_hidden_dim, bias=bias) + self.lin2 = 
nn.Linear(num_heads * ms_hidden_dim, num_heads, bias=bias) + + def forward(self, dot_product_score, cost_mat_score): + # dot_product_score shape: (batch, head_num, row_cnt, col_cnt) + # cost_mat_score shape: (batch, head_num, row_cnt, col_cnt) + # shape: (batch, head_num, row_cnt, col_cnt, 2) + two_scores = torch.stack((dot_product_score, cost_mat_score), dim=-1) + two_scores = rearrange(two_scores, "b h r c s -> b r c (h s)") + # shape: (batch, row_cnt, col_cnt, 2 * num_heads) + ms1 = self.lin1(two_scores) + ms1_activated = F.relu(ms1) + # shape: (batch, row_cnt, col_cnt, num_heads) + ms2 = self.lin2(ms1_activated) + # shape: (batch, row_cnt, head_num, col_cnt) + mixed_scores = rearrange(ms2, "b r c h -> b h r c") + + return mixed_scores + + +class EfficientMixedScoreMultiHeadAttention(nn.Module): + def __init__(self, embed_dim: int, num_heads: int, bias: bool = False): + super().__init__() + + qkv_dim = embed_dim // num_heads + + self.num_heads = num_heads + self.qkv_dim = qkv_dim + self.norm_factor = 1 / math.sqrt(qkv_dim) + + self.Wqv1 = nn.Linear(embed_dim, 2 * embed_dim, bias=bias) + self.Wkv2 = nn.Linear(embed_dim, 2 * embed_dim, bias=bias) + + # self.init_parameters() + self.mixed_scores_layer = MixedScoreFF(num_heads, qkv_dim, bias) + + self.out_proj1 = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj2 = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward(self, x1, x2, attn_mask=None, cost_mat=None): + batch_size = x1.size(0) + row_cnt = x1.size(-2) + col_cnt = x2.size(-2) + + # Project query, key, value + q, v1 = rearrange( + self.Wqv1(x1), "b s (two h d) -> two b h s d", two=2, h=self.num_heads + ).unbind(dim=0) + + # Project query, key, value + k, v2 = rearrange( + self.Wqv1(x2), "b s (two h d) -> two b h s d", two=2, h=self.num_heads + ).unbind(dim=0) + + # shape: (batch, num_heads, row_cnt, col_cnt) + dot = self.norm_factor * torch.matmul(q, k.transpose(-2, -1)) + + if cost_mat is not None: + # shape: (batch, num_heads, row_cnt, 
col_cnt) + cost_mat_score = cost_mat[:, None, :, :].expand_as(dot) + dot = self.mixed_scores_layer(dot, cost_mat_score) + + if attn_mask is not None: + attn_mask = attn_mask.view(batch_size, 1, row_cnt, col_cnt).expand_as(dot) + dot.masked_fill_(~attn_mask, float("-inf")) + + h1 = self.out_proj1(apply_weights_and_combine(dot, v2)) + h2 = self.out_proj2(apply_weights_and_combine(dot.transpose(-2, -1), v1)) + + return h1, h2 + + +class EncoderLayer(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int = 8, + feedforward_hidden: int = 512, + normalization: str = "batch", + bias: bool = False, + ): + super().__init__() + + self.op_attn = MultiHeadAttention(embed_dim, num_heads, bias=bias) + self.ma_attn = MultiHeadAttention(embed_dim, num_heads, bias=bias) + self.cross_attn = EfficientMixedScoreMultiHeadAttention( + embed_dim, num_heads, bias=bias + ) + + self.op_ffn = TransformerFFN(embed_dim, feedforward_hidden, normalization) + self.ma_ffn = TransformerFFN(embed_dim, feedforward_hidden, normalization) + + self.op_norm = Normalization(embed_dim, normalization) + self.ma_norm = Normalization(embed_dim, normalization) + + def forward( + self, op_in, ma_in, cost_mat, op_mask=None, ma_mask=None, cross_mask=None + ): + op_cross_out, ma_cross_out = self.cross_attn( + op_in, ma_in, attn_mask=cross_mask, cost_mat=cost_mat + ) + op_cross_out = self.op_norm(op_cross_out + op_in) + ma_cross_out = self.ma_norm(ma_cross_out + ma_in) + + # (bs, num_jobs, ops_per_job, d) + op_self_out = self.op_attn(op_cross_out, attn_mask=op_mask) + # (bs, num_ma, d) + ma_self_out = self.ma_attn(ma_cross_out, attn_mask=ma_mask) + + op_out = self.op_ffn(op_cross_out, op_self_out) + ma_out = self.ma_ffn(ma_cross_out, ma_self_out) + + return op_out, ma_out + + +class Encoder(nn.Module): + def __init__( + self, + embed_dim: int = 256, + num_heads: int = 16, + num_layers: int = 5, + normalization: str = "batch", + feedforward_hidden: int = 512, + init_embedding: nn.Module = None, + 
init_embedding_kwargs: dict = {}, + bias: bool = False, + ): + super().__init__() + self.d_model = embed_dim + + if init_embedding is None: + init_embedding = env_init_embedding( + "matnet", {"embed_dim": embed_dim, **init_embedding_kwargs} + ) + self.init_embedding = init_embedding + self.layers = nn.ModuleList( + [ + EncoderLayer( + embed_dim=embed_dim, + num_heads=num_heads, + feedforward_hidden=feedforward_hidden, + normalization=normalization, + bias=bias, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, td, attn_mask: torch.Tensor = None): + # [BS, num_machines, emb], [BS, num_operations, emb] + ops_embed, ma_embed, edge_feat = self.init_embedding(td) + try: + # mask padded ops; shape=(bs, ops) + ops_attn_mask = ~td["pad_mask"] + except KeyError: + ops_attn_mask = None + # padded ops should also be masked in cross attention; shape=(bs, ops, ma) + # cross_mask = ops_attn_mask.unsqueeze(-1).expand(-1, -1, ma_embed.size(1)) + for layer in self.layers: + ops_embed, ma_embed = layer( + ops_embed, + ma_embed, + cost_mat=edge_feat, + op_mask=ops_attn_mask, # mask padded operations in attention + ma_mask=None, # no padding for machines + cross_mask=None, + ) + embedding = (ops_embed, ma_embed) + return embedding, None diff --git a/rl4co/models/zoo/matnet/policy.py b/rl4co/models/zoo/matnet/policy.py index 24212310..0c2af426 100644 --- a/rl4co/models/zoo/matnet/policy.py +++ b/rl4co/models/zoo/matnet/policy.py @@ -141,7 +141,7 @@ def pre_forward(self, td: TensorDict, env: FFSPEnv, num_starts: int): encoder = self.encoders[stage_idx] embeddings, _ = encoder(td) decoder = self.decoders[stage_idx] - decoder._precompute_cache(embeddings, td) + decoder._precompute_cache(embeddings) if num_starts > 1: # repeat num_start times diff --git a/rl4co/tasks/train.py b/rl4co/tasks/train.py index 6826382d..b1d104b4 100644 --- a/rl4co/tasks/train.py +++ b/rl4co/tasks/train.py @@ -47,7 +47,7 @@ def run(cfg: DictConfig) -> Tuple[dict, dict]: callbacks: List[Callback] = 
utils.instantiate_callbacks(cfg.get("callbacks")) log.info("Instantiating loggers...") - logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) + logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"), model) log.info("Instantiating trainer...") trainer: RL4COTrainer = hydra.utils.instantiate( diff --git a/rl4co/utils/instantiators.py b/rl4co/utils/instantiators.py index e3b25183..9f4cdaf4 100644 --- a/rl4co/utils/instantiators.py +++ b/rl4co/utils/instantiators.py @@ -31,14 +31,14 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: return callbacks -def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: +def instantiate_loggers(logger_cfg: DictConfig, model) -> List[Logger]: """Instantiates loggers from config.""" - logger: List[Logger] = [] + logger_list: List[Logger] = [] if not logger_cfg: log.warning("No logger configs found! Skipping...") - return logger + return logger_list if not isinstance(logger_cfg, DictConfig): raise TypeError("Logger config must be a DictConfig!") @@ -46,6 +46,16 @@ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: for _, lg_conf in logger_cfg.items(): if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: log.info(f"Instantiating logger <{lg_conf._target_}>") - logger.append(hydra.utils.instantiate(lg_conf)) - - return logger + if hasattr(lg_conf, "log_gradients"): + log_gradients = lg_conf.get("log_gradients", False) + # manually remove parameter, since pop doesnt work on DictConfig + del lg_conf.log_gradients + else: + log_gradients = False + logger = hydra.utils.instantiate(lg_conf) + if hasattr(logger, "watch") and log_gradients: + # make use of wandb gradient statistics logger + logger.watch(model, log_graph=False) + logger_list.append(logger) + + return logger_list diff --git a/rl4co/utils/ops.py b/rl4co/utils/ops.py index ce32aeef..86dc2a1e 100644 --- a/rl4co/utils/ops.py +++ b/rl4co/utils/ops.py @@ -73,7 +73,8 @@ def gather_by_index(src, idx, 
dim=1, squeeze=True): expanded_shape = list(src.shape) expanded_shape[dim] = -1 idx = idx.view(idx.shape + (1,) * (src.dim() - idx.dim())).expand(expanded_shape) - return src.gather(dim, idx).squeeze() if squeeze else src.gather(dim, idx) + squeeze = idx.size(dim) == 1 and squeeze + return src.gather(dim, idx).squeeze(dim) if squeeze else src.gather(dim, idx) def unbatchify_and_gather(x: Tensor, idx: Tensor, n: int): @@ -151,8 +152,8 @@ def select_start_nodes(td, env, num_starts): torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) % num_loc ) - elif env.name == "fjsp": - raise NotImplementedError("Multistart not yet supported for FJSP") + elif env.name in ["jssp", "fjsp"]: + raise NotImplementedError("Multistart not yet supported for FJSP/JSSP") else: # Environments with depot: we do not select the depot as a start node selected = ( @@ -221,6 +222,26 @@ def get_full_graph_edge_index(num_node: int, self_loop=False) -> Tensor: return edge_index +def adj_to_pyg_edge_index(adj: Tensor) -> Tensor: + """transforms an adjacency matrix (boolean) to a Tensor with the respective edge + indices (in the format required by the pytorch geometric module). + + :param Tensor adj: shape=(bs, num_nodes, num_nodes) + :return Tensor: shape=(2, num_edges) + """ + assert adj.size(1) == adj.size(2), "only symmetric adjacency matrices are supported" + num_nodes = adj.size(1) + # (num_edges, 3) + edge_idx = adj.nonzero() + batch_idx = edge_idx[:, 0] * num_nodes + # PyG expects a "single, flat graph", in which the graphs of the batch are not connected. + # Therefore, add the batch_idx to edge_idx to have unique indices + flat_edge_idx = edge_idx[:, 1:] + batch_idx[:, None] + # (2, num_edges) + flat_edge_idx = torch.permute(flat_edge_idx, (1, 0)) + return flat_edge_idx + + def sample_n_random_actions(td: TensorDict, n: int): """Helper function to sample n random actions from available actions. 
If number of valid actions is less then n, we sample with replacement from the diff --git a/tests/test_envs.py b/tests/test_envs.py index bd8d416b..775a46be 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -4,6 +4,8 @@ import pytest import torch +from tensordict import TensorDict + from rl4co.envs import ( ATSPEnv, CVRPEnv, @@ -11,6 +13,7 @@ DPPEnv, FFSPEnv, FJSPEnv, + JSSPEnv, MDCPDPEnv, MDPPEnv, MTSPEnv, @@ -46,7 +49,6 @@ MTSPEnv, ATSPEnv, MDCPDPEnv, - FJSPEnv, ], ) def test_routing(env_cls, batch_size=2, size=20): @@ -89,12 +91,12 @@ def test_eda(env_cls, batch_size=2, max_decaps=5): assert reward.shape == (batch_size,) -@pytest.mark.parametrize("env_cls", [FFSPEnv]) -def test_scheduling(env_cls, batch_size=2): +@pytest.mark.parametrize("env_cls", [FFSPEnv, FJSPEnv, JSSPEnv]) +@pytest.mark.parametrize("mask_no_ops", [True, False]) +def test_scheduling(env_cls, mask_no_ops, batch_size=2): env = env_cls() - td = env.reset(batch_size=[batch_size]) - td["action"] = torch.tensor([1, 1]) - td = env._step(td) + reward, td, actions = rollout(env, env.reset(batch_size=[batch_size]), random_policy) + assert reward.shape == (batch_size,) @pytest.mark.parametrize("env_cls", [SMTWTPEnv]) @@ -102,3 +104,45 @@ def test_smtwtp(env_cls, batch_size=2): env = env_cls(num_job=4) reward, td, actions = rollout(env, env.reset(batch_size=[batch_size]), random_policy) assert reward.shape == (batch_size,) + + +@pytest.mark.parametrize("env_cls", [JSSPEnv]) +def test_jssp_lb(env_cls): + env = env_cls(generator_params={"num_jobs": 2, "num_machines": 2}) + td = TensorDict( + { + "proc_times": torch.tensor( + [[[1, 0, 0, 4], [0, 2, 3, 0]]], dtype=torch.float32 + ), + "start_op_per_job": torch.tensor([[0, 2]], dtype=torch.long), + "end_op_per_job": torch.tensor([[1, 3]], dtype=torch.long), + "pad_mask": torch.tensor([[0, 0, 0, 0]], dtype=torch.bool), + }, + batch_size=[1], + ) + + td = env.reset(td) + + actions = [0, 1, 1] + for action in actions: + # NOTE add 1 to account for 
dummy action (waiting) + td.set("action", torch.tensor([action + 1], dtype=torch.long)) + td = env.step(td)["next"] + + lb_expected = torch.tensor([[1, 5, 3, 7]], dtype=torch.float32) + assert torch.allclose(td["lbs"], lb_expected) + + +def test_scheduling_dataloader(): + from tempfile import TemporaryDirectory + + from rl4co.envs.scheduling.fjsp.parser import write + + write_env = FJSPEnv() + + td = write_env.reset(batch_size=[2]) + with TemporaryDirectory() as tmpdirname: + write(tmpdirname, td) + read_env = FJSPEnv(generator_params={"file_path": tmpdirname}) + td = read_env.reset(batch_size=2) + assert td.size(0) == 2 diff --git a/tests/test_training.py b/tests/test_training.py index 89513ce4..1f249e55 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -3,7 +3,7 @@ import pytest -from rl4co.envs import ATSPEnv, PDPEnv, PDPRuinRepairEnv, TSPEnv +from rl4co.envs import ATSPEnv, FJSPEnv, JSSPEnv, PDPEnv, PDPRuinRepairEnv, TSPEnv from rl4co.models.rl import A2C, PPO, REINFORCE from rl4co.models.zoo import ( MDAM, @@ -14,6 +14,7 @@ EASEmb, EASLay, HeterogeneousAttentionModel, + L2DPPOModel, MatNet, NARGNNPolicy, SymNCO, @@ -177,3 +178,19 @@ def test_N2S(): ) trainer.fit(model) trainer.test(model) + + +@pytest.mark.parametrize("env_cls", [FJSPEnv, JSSPEnv]) +def test_l2d_ppo(env_cls): + env = env_cls(stepwise_reward=True, _torchrl_mode=True) + model = L2DPPOModel( + env, train_data_size=10, val_data_size=10, test_data_size=10, buffer_size=1000 + ) + trainer = RL4COTrainer( + max_epochs=1, + gradient_clip_val=0.05, + devices=1, + accelerator=accelerator, + ) + trainer.fit(model) + trainer.test(model)