diff --git a/alpa/collective/collective_group/base_collective_group.py b/alpa/collective/collective_group/base_collective_group.py index 8bfc2304a..4efdb249d 100644 --- a/alpa/collective/collective_group/base_collective_group.py +++ b/alpa/collective/collective_group/base_collective_group.py @@ -113,6 +113,7 @@ def get_access_counter(self): def destroy_store(self): """Delete the named actor.""" ray.kill(self._store) + # ray.get(self._store.__ray_terminate__.remote()) self._store = None diff --git a/alpa/collective/collective_group/nccl_collective_group.py b/alpa/collective/collective_group/nccl_collective_group.py index 460c8b94b..0c74edd18 100644 --- a/alpa/collective/collective_group/nccl_collective_group.py +++ b/alpa/collective/collective_group/nccl_collective_group.py @@ -725,7 +725,7 @@ def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=None): "NCCLUniqueID has been broadcasted. The " "NCCLUniqueIDStore will go out of context and be " "destroyed.") - rendezvous.destroy_store() + # rendezvous.destroy_store() return nccl_uid diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py index 4991f81dc..b0b43c679 100644 --- a/alpa/device_mesh.py +++ b/alpa/device_mesh.py @@ -30,6 +30,7 @@ import time from typing import Any, List, Union, Sequence, Tuple, Optional +import jax from jax import core, xla, device_put from jax._src.api import ShapeDtypeStruct from jax._src.lib import xla_bridge as xb, xla_extension as xe @@ -203,6 +204,7 @@ def shard_and_put_non_zero_buffer(self, uuids: Union[Sequence[int], int], shard_shape.append(dim_size) arys[b][device_id] = (self.backend.buffer_from_pyval( np.full(shard_shape, 1e-8, dtype), + # np.random.normal(0, 0.0006, shard_shape).astype(dtype), self.local_devices[device_id])) for uuid, ary in zip(uuids, arys): self.buffers[uuid] = ary @@ -231,6 +233,34 @@ def get_buffers(self, for uuid, local_ids in zip(uuids, device_indices) ] + def copy_buffer(self, + shape, + src_indices, + dst_indices, + target_uuid, + src_uuid, + dtype): + # print(f"target_uuid: {target_uuid}, src_uuid: {src_uuid}...") + datas = self.buffers[src_uuid] + # print(f"old data: {datas} , size: {len(datas)}") + assert len(datas) == self.num_devices + assert len(datas) == len(src_indices) + assert len(datas) == len(dst_indices) + new_datas = [] + + if src_indices == dst_indices: + logger.debug("Indices are the same...") + for i, data in enumerate(datas): + new_datas.append(self.backend.buffer_from_pyval(np.array(data, dtype=dtype), data.device())) + else: + logger.debug("Indices are different... Resharding!") + src_array = np.zeros(shape, dtype=dtype) + for device_id, ind in enumerate(src_indices): + src_array[ind] = np.array(datas[device_id]) + for i, data in enumerate(datas): + new_datas.append(self.backend.buffer_from_pyval(src_array[dst_indices[i]], data.device())) + self.buffers[target_uuid] = new_datas + def delete_buffers(self, uuids: Union[Sequence[int], int]): if isinstance(uuids, Iterable): for uuid in uuids: @@ -1485,9 +1515,9 @@ class DistributedArray: a normal numpy array. Internally, it stores a pointer to all remote buffers. - The buffers are stored distributedly on remote workers' device memeory. + The buffers are stored distributedly on remote workers' device memory. When users require the value of the array. These buffers will be gathered - to the dirver. + to the driver. 
""" def __init__(self, @@ -1778,6 +1808,42 @@ def prefetch(dis_arrays: Sequence[Union[ShardedDeviceArray, DistributedArray, array._fetched_np_buffers = np_value # pylint: disable=protected-access +def copy_distributed_array(src_array: Union[DistributedArray, ReplicatedDistributedArray], + target_sharding_spec: ShardingSpec, + target_dtype: jnp.dtype): + aval = jax.core.ShapedArray(src_array.aval.shape, target_dtype) + if isinstance(src_array, DistributedArray): + mesh = src_array.device_mesh + src_spec = src_array.sharding_spec + ary_refs, ary_uuid = create_remote_array_refs(mesh) + dst_array = DistributedArray(mesh, aval, target_sharding_spec, ary_refs[0]) + if src_array.sharding_spec != target_sharding_spec: + print("Sharding spec changed. Will need resharding..." + f"src: {src_array.sharding_spec}, dst: {dst_array.sharding_spec}") + print(f"src_shape {src_array.aval.shape}, dst_shape {dst_array.aval.shape}, " + f"src_array_indices: {src_array.indices}, dst_array indices: {dst_array.indices}") + # Do actual copy + for w in mesh.workers: + w.copy_buffer.remote(dst_array.aval.shape, + src_array.indices, + dst_array.indices, + dst_array.remote_ref.uuid, + src_array.remote_ref.uuid, + target_dtype) + else: + assert isinstance(src_array, ReplicatedDistributedArray) + meshes = [] + arrays = [] + for mesh in src_array._mesh_array_map: + meshes.append(mesh) + ary = copy_distributed_array(src_array._mesh_array_map[mesh], + target_sharding_spec, + target_dtype) + arrays.append(ary) + dst_array = ReplicatedDistributedArray(meshes, arrays) + return dst_array + + ######################################## ##### Physical Mesh Group ##### ######################################## diff --git a/alpa/global_env.py b/alpa/global_env.py index d543ee673..839297aea 100644 --- a/alpa/global_env.py +++ b/alpa/global_env.py @@ -12,7 +12,7 @@ def __init__(self): # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html self.xla_client_mem_fraction = float( - os.environ.get("XLA_PYTHON_CLIENT_MEM_FRACTION", 0.9)) + os.environ.get("XLA_PYTHON_CLIENT_MEM_FRACTION", 0.8)) self.xla_client_client_preallocate = os.environ.get( "XLA_PYTHON_CLIENT_PREALLOCATE", "true") # The threshold to tigger a batched deletion on workers. @@ -72,10 +72,10 @@ def __init__(self): self.use_local_allgather = True # Cross mesh resharding mode. Possible choices: {"send_recv", # "broadcast"} - self.resharding_mode = "send_recv" + self.resharding_mode = "broadcast" # Which nccl to use. Possible choices: {"cupy", # "xla_extension"} - self.nccl_mode = "cupy" + self.nccl_mode = "xla_extension" self.enable_overlapping = False # Cross mesh resharding load balancing mode. # Possible choices: {"normal", "no_loadbalance", diff --git a/alpa/model/model_util.py b/alpa/model/model_util.py index 7388d98c6..5103b0119 100644 --- a/alpa/model/model_util.py +++ b/alpa/model/model_util.py @@ -4,7 +4,9 @@ import functools from typing import Any, Callable, Optional, Tuple, Optional, Union, Sequence +import alpa.device_mesh from alpa.api import value_and_grad +from alpa.device_mesh import copy_distributed_array import flax from flax.training import train_state, dynamic_scale as dynamic_scale_lib from flax.training.dynamic_scale import DynamicScaleResult @@ -348,6 +350,30 @@ def create(cls, *, apply_fn, params, tx, use_master_copy=False, **kwargs): **kwargs, ) + @classmethod + def create_distributed(cls, *, apply_fn, params, tx, use_master_copy=False, **kwargs): + """The distributed version of create. 
It assumes the inputs are DistributedArrays.""" + if use_master_copy: + dtype = jax.tree_util.tree_flatten(params)[0][0].dtype + assert dtype == jnp.float16 + # create the master copy distributedly + master_copy = jax.tree_util.tree_map( + lambda x: copy_distributed_array(x, jnp.float32), params) + # TODO (Hao): handle opt_state + opt_state = tx.init(master_copy) + else: + master_copy = None + opt_state = tx.init(params) + return cls( + step=np.array(0, dtype=np.int32), + apply_fn=apply_fn, + params=params, + master_copy=master_copy, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + @classmethod def create_aval(cls, *, @@ -377,6 +403,39 @@ def create_aval(cls, **kwargs, ) + @classmethod + def create_from(cls, + *, + train_state, + params, + use_master_copy=False, + **kwargs): + """Create a new instance where everything except master_copy is given.""" + if use_master_copy: + dtype = jax.tree_util.tree_flatten(params)[0][0].dtype + assert dtype == jnp.float16 + + def get_sharding_spec(array): + if isinstance(array, alpa.device_mesh.DistributedArray): + return array.sharding_spec + else: + assert isinstance(array, alpa.device_mesh.ReplicatedDistributedArray) + return array.replica.sharding_spec + + # create the master copy distributedly + master_copy = jax.tree_util.tree_map( + lambda x, y: copy_distributed_array(x, get_sharding_spec(y), jnp.float32), params, train_state.master_copy) + else: + master_copy = None + return cls( + step=train_state.step, + apply_fn=train_state.apply_fn, + params=params, + master_copy=master_copy, + tx=train_state.tx, + opt_state=train_state.opt_state, + **kwargs + ) class DynamicScale(struct.PyTreeNode): """This is the same as flax.optim.DynamicScale, except that diff --git a/alpa/parallel_method.py b/alpa/parallel_method.py index 3b92902b9..713ec684d 100644 --- a/alpa/parallel_method.py +++ b/alpa/parallel_method.py @@ -13,7 +13,7 @@ - PipeshardParallel: which combines pipeline parallelism and shard parallelism. """ from abc import ABC, abstractmethod -from typing import Callable, Optional, Sequence, Union, Any +from typing import Callable, Optional, Sequence, Union, Any, List from jax import linear_util as lu from jax._src import traceback_util @@ -248,9 +248,11 @@ def get_3d_parallel_method(num_micro_batches: int, data_parallel: int, operator_parallel: int, pipeline_parallel: int, - allow_degenerate_into_shard_parallel: bool = True): + allow_degenerate_into_shard_parallel: bool = True, + use_manual_layer_option: bool = False, + forward_stage_layer_ids: List[List[int]] = None): """ - Get a parallel method for 3D parallelism, which reguarlly combines + Get a parallel method for 3D parallelism, which regularly combines data parallelism, operator parallelism and pipeline parallelism. 
""" # Validity check @@ -259,6 +261,7 @@ def get_3d_parallel_method(num_micro_batches: int, num_devices_per_host = virtual_mesh.num_devices_per_host if data_parallel == -1: data_parallel = (num_devices // operator_parallel // pipeline_parallel) + print(f"num_devices {num_devices} dp {data_parallel}, op {operator_parallel}, pp {pipeline_parallel}") assert num_devices % data_parallel == 0 assert num_devices % operator_parallel == 0 assert num_devices % pipeline_parallel == 0 @@ -287,7 +290,15 @@ def get_3d_parallel_method(num_micro_batches: int, [data_parallel, operator_parallel])) # Return pipeshard parallel - layer_option = AutoLayerOption(layer_num=pp, eps=0.1) + if use_manual_layer_option: + # We assume each layer has been annotated using the mark_pipeline_boundary() + layer_option = ManualLayerOption() + assert forward_stage_layer_ids, "forward_stage_layer_ids must be provided " \ + "when using manual annotation." + else: + # Note: this eps need some tuning. + layer_option = AutoLayerOption(layer_num=pp, eps=0.1) + forward_stage_layer_ids = [[i] for i in range(pp)] return PipeshardParallel( devices=virtual_mesh, num_micro_batches=num_micro_batches, @@ -297,7 +308,7 @@ def get_3d_parallel_method(num_micro_batches: int, ), layer_option=layer_option, stage_option=ManualStageOption( - forward_stage_layer_ids=[[i] for i in range(pp)], + forward_stage_layer_ids=forward_stage_layer_ids, submesh_physical_shapes=[physical_mesh_shape] * pp, submesh_logical_shapes=[logical_mesh_shape] * pp, submesh_autosharding_option_dicts=[{}] * pp)) diff --git a/alpa/pipeline_parallel/computation.py b/alpa/pipeline_parallel/computation.py index 3dc2c7b68..feac7c966 100644 --- a/alpa/pipeline_parallel/computation.py +++ b/alpa/pipeline_parallel/computation.py @@ -28,7 +28,7 @@ get_compile_options, jaxpr_to_hlo, setup_computation_alias, compile_dummy_zero_constant, get_var_mapping, undefined_sharding_spec_proto, - new_jaxpr_eqn) + new_jaxpr_eqn, replicated_sharding_spec_proto) from alpa.wrapped_hlo import HloStatus, WrappedHlo # pylint: disable=redefined-builtin @@ -750,9 +750,13 @@ def generate_sharded_xla_computations_arguments( hlo.set_input_shardings(sharding_protos) if output_sharding_dict: - sharding_protos = [ - output_sharding_dict[x].sharding_proto() for x in outvars - ] + sharding_protos = [] + for x in outvars: + spec = output_sharding_dict.get(x, None) + if spec is None: + sharding_protos.append(replicated_sharding_spec_proto()) + else: + sharding_protos.append(spec.sharding_proto()) hlo.set_output_shardings(sharding_protos) if stage_input_sharding: diff --git a/alpa/util.py b/alpa/util.py index 72c58c3cb..1af4df5ad 100644 --- a/alpa/util.py +++ b/alpa/util.py @@ -614,6 +614,13 @@ def undefined_sharding_spec_proto(): return proto +def replicated_sharding_spec_proto(): + """Return a proto of ShardingSpec which represents a replicated spec.""" + proto = xc.OpSharding() + proto.type = xc.OpSharding.Type.REPLICATED + return proto + + ######################################## ##### Jaxpr Utilities ######################################## diff --git a/examples/llm_serving/model/opt_utils.py b/examples/llm_serving/model/opt_utils.py index 4fcd10e2c..a4a466f6c 100644 --- a/examples/llm_serving/model/opt_utils.py +++ b/examples/llm_serving/model/opt_utils.py @@ -4,7 +4,7 @@ from jax import xla, jit from jax.core import Primitive from jax._src.lib import xla_client as xc -from transformers.generation_utils import dataclass +from dataclasses import dataclass def sync(device_id=0): diff --git 
a/examples/llm_serving/model/wrapper.py b/examples/llm_serving/model/wrapper.py index a2629f9b8..ee4d2ffda 100644 --- a/examples/llm_serving/model/wrapper.py +++ b/examples/llm_serving/model/wrapper.py @@ -13,7 +13,9 @@ jax_index_select) from tqdm import tqdm from transformers import OPTForCausalLM, BloomForCausalLM -from transformers.generation_utils import GenerationMixin, ModelOutput, dataclass +from transformers import GenerationMixin +from transformers.utils import ModelOutput +from dataclasses import dataclass import alpa from alpa.device_mesh import DistributedArray diff --git a/examples/opt_finetune/config_125m.json b/examples/opt_finetune/config_125m.json new file mode 100644 index 000000000..08a88e353 --- /dev/null +++ b/examples/opt_finetune/config_125m.json @@ -0,0 +1,29 @@ +{ +"_name_or_path": "facebook/opt-125m", +"weight_path": "/home/ubuntu/dataset/opt_weights/125M_np", +"activation_dropout": 0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 3072, +"hidden_size": 768, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 12, +"num_hidden_layers": 12, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 768 +} \ No newline at end of file diff --git a/examples/opt_finetune/config_175b.json b/examples/opt_finetune/config_175b.json new file mode 100644 index 000000000..5ba97d66c --- /dev/null +++ b/examples/opt_finetune/config_175b.json @@ -0,0 +1,29 @@ +{ +"_name_or_path": "./", +"_remove_final_layer_norm": false, +"activation_dropout": 0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 49152, +"hidden_size": 12288, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 96, +"num_hidden_layers": 96, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 12288 +} \ No newline at end of file diff --git a/examples/opt_finetune/config_2.7b.json b/examples/opt_finetune/config_2.7b.json new file mode 100644 index 000000000..91aa0431b --- /dev/null +++ b/examples/opt_finetune/config_2.7b.json @@ -0,0 +1,29 @@ +{ +"weight_path": "/home/ubuntu/dataset/opt_weights/2.7B_np", +"_remove_final_layer_norm": false, +"activation_dropout": 0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 10240, +"hidden_size": 2560, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 32, +"num_hidden_layers": 32, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 2560 +} diff --git a/examples/opt_finetune/config_30b.json b/examples/opt_finetune/config_30b.json new file mode 100644 index 000000000..e6eb0cd1c --- /dev/null +++ b/examples/opt_finetune/config_30b.json @@ -0,0 +1,29 @@ +{ +"weight_path": "/home/ubuntu/dataset/opt_weights/30B_np", 
+"_remove_final_layer_norm": false, +"activation_dropout": 0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 28672, +"hidden_size": 7168, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 56, +"num_hidden_layers": 48, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 7168 +} \ No newline at end of file diff --git a/examples/opt_finetune/config_66b.json b/examples/opt_finetune/config_66b.json new file mode 100644 index 000000000..c171d64e8 --- /dev/null +++ b/examples/opt_finetune/config_66b.json @@ -0,0 +1,30 @@ +{ +"_name_or_path": "facebook/opt-66b", +"weight_path": "/home/ubuntu/tmp_dataset/opt_weights/66B_np", +"_remove_final_layer_norm": false, +"activation_dropout": 0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 36864, +"hidden_size": 9216, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 72, +"num_hidden_layers": 64, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 9216 +} \ No newline at end of file diff --git a/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile b/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile new file mode 100644 index 000000000..18cc8f7a3 --- /dev/null +++ b/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile @@ -0,0 +1,92 @@ +# base docker image +FROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04 + +# InfiniBand (IB) dependencies adopoted from CoreWeave's github +# https://github.com/coreweave/nccl-tests +ARG DEBIAN_FRONTEND=noninteractive +RUN apt-get -qq update && \ + apt-get -qq install -y --allow-change-held-packages --no-install-recommends \ + build-essential libtool autoconf automake autotools-dev unzip \ + ca-certificates \ + wget curl openssh-server vim environment-modules \ + iputils-ping net-tools \ + libnuma1 libsubunit0 libpci-dev \ + libpmix-dev \ + datacenter-gpu-manager + +# Mellanox OFED (latest) +RUN wget -qO - https://www.mellanox.com/downloads/ofed/RPM-GPG-KEY-Mellanox | apt-key add - +RUN cd /etc/apt/sources.list.d/ && wget https://linux.mellanox.com/public/repo/mlnx_ofed/latest/ubuntu18.04/mellanox_mlnx_ofed.list +RUN apt-get -qq update \ + && apt-get -qq install -y --no-install-recommends \ + ibverbs-utils libibverbs-dev libibumad3 libibumad-dev librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils \ + && rm -rf /var/lib/apt/lists/* + +# HPC-X (2.12) +ENV HPCX_VERSION=2.12 +RUN cd /tmp && \ + wget -q -O - http://blobstore.s3.ord1.coreweave.com/drivers/hpcx-v${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl${HPCX_VERSION}-x86_64.tbz | tar xjf - && \ + mv hpcx-v${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl${HPCX_VERSION}-x86_64 /opt/hpcx + +# GDRCopy userspace components (2.3) +RUN cd /tmp && \ + wget -q https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2011.4/x86/Ubuntu20.04/gdrcopy-tests_2.3-1_amd64.cuda11_4.Ubuntu20_04.deb && \ + wget -q 
https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2011.4/x86/Ubuntu20.04/libgdrapi_2.3-1_amd64.Ubuntu20_04.deb && \ + dpkg -i *.deb && \ + rm *.deb + +# Begin auto-generated paths +ENV HPCX_DIR=/opt/hpcx +ENV HPCX_UCX_DIR=/opt/hpcx/ucx +ENV HPCX_UCC_DIR=/opt/hpcx/ucc +ENV HPCX_SHARP_DIR=/opt/hpcx/sharp +ENV HPCX_NCCL_RDMA_SHARP_PLUGIN_DIR=/opt/hpcx/nccl_rdma_sharp_plugin +ENV HPCX_HCOLL_DIR=/opt/hpcx/hcoll +ENV HPCX_MPI_DIR=/opt/hpcx/ompi +ENV HPCX_OSHMEM_DIR=/opt/hpcx/ompi +ENV HPCX_MPI_TESTS_DIR=/opt/hpcx/ompi/tests +ENV HPCX_OSU_DIR=/opt/hpcx/ompi/tests/osu-micro-benchmarks-5.8 +ENV HPCX_OSU_CUDA_DIR=/opt/hpcx/ompi/tests/osu-micro-benchmarks-5.8-cuda +ENV HPCX_IPM_DIR=/opt/hpcx/ompi/tests/ipm-2.0.6 +ENV HPCX_CLUSTERKIT_DIR=/opt/hpcx/clusterkit +ENV OMPI_HOME=/opt/hpcx/ompi +ENV MPI_HOME=/opt/hpcx/ompi +ENV OSHMEM_HOME=/opt/hpcx/ompi +ENV OPAL_PREFIX=/opt/hpcx/ompi +ENV PATH=/opt/hpcx/clusterkit/bin:/opt/hpcx/hcoll/bin:/opt/hpcx/ucc/bin:/opt/hpcx/ucx/bin:/opt/hpcx/ompi/bin:$PATH +ENV LD_LIBRARY_PATH=/opt/hpcx/nccl_rdma_sharp_plugin/lib:/opt/hpcx/ucc/lib/ucc:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib/ucx:/opt/hpcx/ucx/lib:/opt/hpcx/sharp/lib:/opt/hpcx/hcoll/lib:/opt/hpcx/ompi/lib:$LD_LIBRARY_PATH +ENV LIBRARY_PATH=/opt/hpcx/nccl_rdma_sharp_plugin/lib:/opt/hpcx/ompi/lib:/opt/hpcx/sharp/lib:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib:/opt/hpcx/hcoll/lib:/opt/hpcx/ompi/lib:/usr/local/cuda/lib64/stubs +ENV CPATH=/opt/hpcx/ompi/include:/opt/hpcx/ucc/include:/opt/hpcx/ucx/include:/opt/hpcx/sharp/include:/opt/hpcx/hcoll/include: +ENV PKG_CONFIG_PATH=/opt/hpcx/hcoll/lib/pkgconfig:/opt/hpcx/sharp/lib/pkgconfig:/opt/hpcx/ucx/lib/pkgconfig:/opt/hpcx/ompi/lib/pkgconfig: +# End of auto-generated paths + +# install common tool & conda +RUN apt update && \ + apt install -y wget git vim screen && \ + wget --quiet https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh -O ~/anaconda.sh && \ + /bin/bash ~/anaconda.sh -b -p /opt/conda && \ + rm ~/anaconda.sh && \ + mkdir -p /opt/conda/envs/alpa && \ + ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ + echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ + echo "conda activate base" >> ~/.bashrc + +# Some of my own dev config +RUN wget --quiet https://raw.githubusercontent.com/zhisbug/RC/master/.screenrc -P /root/ && \ + wget --quiet https://raw.githubusercontent.com/zhisbug/RC/master/.vimrc -P /root/ + +# install conda alpa env +RUN . 
/opt/conda/etc/profile.d/conda.sh && \ + conda create --name alpa python=3.8 -y && \ + conda activate alpa && \ + apt install coinor-cbc -y && \ + pip3 install --upgrade pip && \ + pip3 install cupy-cuda113 && \ + pip3 install alpa && \ + pip3 install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html && \ + pip3 install datasets && \ + pip3 install transformers && \ + pip3 install tensorflow-gpu + +# Execute in Alpa conda env +ENV PATH /opt/conda/envs/alpa/bin:$PATH diff --git a/examples/opt_finetune/estimate_peak_memory.py b/examples/opt_finetune/estimate_peak_memory.py new file mode 100644 index 000000000..e46152f85 --- /dev/null +++ b/examples/opt_finetune/estimate_peak_memory.py @@ -0,0 +1,106 @@ +import dataclasses +from dataclasses import dataclass +import jax.numpy as jnp + +@dataclass(frozen=True) +class OPTConfig: + # Inherited from OPT + num_hidden_layers: int = 12 + max_seq_len: int = 2048 + hidden_size: int = 768 + n_head: int = 12 + input_dim: int = 768 + ffn_embed_dim: int = 3072 + pad: int = 1 + activation_fn: str = 'relu' + dtype: any = jnp.float16 + use_stable_embedding: bool = False + no_scale_embedding: bool = True + decoder_learned_pos: bool = True + decoder_normalize_before: bool = True + share_decoder_input_output_embed: bool = True + # Added + version: int = 1 + vocab_size: int = 50272 + layer_norm_eps: float = 0.00001 + num_pp_stages: int = None + # parallelize + mark_boundary: bool = True + + +def get_config(name, **kwargs): + if name == "opt-125m": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=12, n_head=12, + hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4, + version=3, + ) + elif name == "opt-350m": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=24, n_head=16, + hidden_size=1024, input_dim=1024, ffn_embed_dim=1024 * 4, + version=2, + ) + raise NotImplementedError("Not implemented because this model " + "has a different architecture") + elif name == "opt-1.3b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=24, n_head=32, + hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4, + version=3, + ) + elif name == "opt-2.7b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=32, n_head=32, + hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4, + version=3, + ) + elif name == "opt-6.7b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=32, n_head=32, + hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4, + version=3, + ) + elif name == "opt-30b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=48, n_head=56, + hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4, + version=3, + ) + elif name == "opt-66b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=64, n_head=72, + hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4, + version=3, + ) + elif name == "opt-175b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=96, n_head=96, + hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, + version=3, + ) + else: + raise ValueError(f"Invalid model name: {name}") + + return dataclasses.replace(config, **kwargs) + +def estimate_peak_memory(model_name: str, mbs: int, seq_len: int): + config = get_config(model_name) + n_layer = config.num_hidden_layers + n_head = config.n_head + h = config.hidden_size + n_params = float(model_name.split("-")[-1][:-1]) * 1e9 + + n_bytes = 16 * n_params + 2 * n_layer * mbs * seq_len * h + \ + seq_len * mbs * h * (34 + 5.0 * n_head * seq_len / h) + mem_gb = n_bytes / 1e9 + print(f"model: 
{model_name}, micro bs: {mbs}, seq_len: {seq_len}, memory lower bound: {mem_gb} GB.") + return mem_gb + + +# estimate_peak_memory("opt-2.7b", 2, 1024) +estimate_peak_memory("opt-30b", 2, 1024) +estimate_peak_memory("opt-175b", 2, 1024) +estimate_peak_memory("opt-66b", 2, 1024) +estimate_peak_memory("opt-66b", 8, 1024) +estimate_peak_memory("opt-66b", 16, 1024) \ No newline at end of file diff --git a/examples/opt_finetune/estimate_throughput.py b/examples/opt_finetune/estimate_throughput.py new file mode 100644 index 000000000..383a9fb50 --- /dev/null +++ b/examples/opt_finetune/estimate_throughput.py @@ -0,0 +1,105 @@ +import dataclasses +from dataclasses import dataclass +import jax.numpy as jnp +import alpa + +@dataclass(frozen=True) +class OPTConfig: + # Inherited from OPT + num_hidden_layers: int = 12 + max_seq_len: int = 2048 + hidden_size: int = 768 + n_head: int = 12 + input_dim: int = 768 + ffn_embed_dim: int = 3072 + pad: int = 1 + activation_fn: str = 'relu' + dtype: any = jnp.float16 + use_stable_embedding: bool = False + no_scale_embedding: bool = True + decoder_learned_pos: bool = True + decoder_normalize_before: bool = True + share_decoder_input_output_embed: bool = True + # Added + version: int = 1 + vocab_size: int = 50272 + layer_norm_eps: float = 0.00001 + num_pp_stages: int = None + # parallelize + mark_boundary: bool = True + + +def get_config(name, **kwargs): + if name == "opt-125m": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=12, n_head=12, + hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4, + version=3, + ) + elif name == "opt-350m": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=24, n_head=16, + hidden_size=1024, input_dim=1024, ffn_embed_dim=1024 * 4, + version=2, + ) + raise NotImplementedError("Not implemented because this model " + "has a different architecture") + elif name == "opt-1.3b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=24, n_head=32, + hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4, + version=3, + ) + elif name == "opt-2.7b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=32, n_head=32, + hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4, + version=3, + ) + elif name == "opt-6.7b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=32, n_head=32, + hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4, + version=3, + ) + elif name == "opt-30b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=48, n_head=56, + hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4, + version=3, + ) + elif name == "opt-66b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=64, n_head=72, + hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4, + version=3, + ) + elif name == "opt-175b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=96, n_head=96, + hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, + version=3, + ) + else: + raise ValueError(f"Invalid model name: {name}") + + return dataclasses.replace(config, **kwargs) + + +def estimate_throughput(model_name: str, gbs: int, seq_len: int, num_device: int, latency: float): + config = get_config(model_name) + n_layer = config.num_hidden_layers + h = config.hidden_size + + throughput_tflops = alpa.util.compute_gpt_tflops( + batch_size=gbs, + seq_len=seq_len, + num_layers=n_layer, + hidden_size=h, + vocab_size=50272, + num_gpus=num_device, + latency=latency) + print(f"Model {model_name}, gbs: {gbs}, seq_len: {seq_len}, Model TFlops: {throughput_tflops}, HW TFlops: 
{throughput_tflops * 4 / 3}..") + + +estimate_throughput("opt-175b", 1536, 2048, 128, 207) \ No newline at end of file diff --git a/examples/opt_finetune/load_params.py b/examples/opt_finetune/load_params.py new file mode 100644 index 000000000..1e60e10b3 --- /dev/null +++ b/examples/opt_finetune/load_params.py @@ -0,0 +1,331 @@ +import os +import itertools +import time + +import numpy as np +import alpa +from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray, + MeshHostWorker, create_remote_array_refs) +import jax +import flax +from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves +from jax.interpreters import pxla + + +def load_opt_params_worker_func_66b(self, path, prefix_to_idx, config, shapes, + uuids, indices, mesh_ids): + """The worker function to load OPT parameters.""" + + def load_array(key): + return np.load(os.path.join(path, key)) + + def load_param(param_key, loaded_array, is_position_embedding=False): + i = prefix_to_idx[param_key] + + for j in range(len(mesh_ids[i])): + if self.mesh_id != mesh_ids[i][j]: + continue + + if not is_position_embedding: + assert shapes[i][j] == loaded_array.shape, ( + f"{shapes[i][j]} vs. {loaded_array.shape}") + else: + if shapes[i][j] != loaded_array.shape: + assert shapes[i][j][1] == loaded_array.shape[1] + loaded_array = loaded_array[:shapes[i][j][0], :] + uuid = uuids[i][j] + datas = [] + for k in range(len(self.local_devices)): + idx = self.host_id * len(self.local_devices) + k + datas.append(loaded_array[indices[i][j][idx]]) + self.put_buffers(uuid, datas) + + # print(f" === Reading on {self.mesh_id} ===") + tic = time.time() + load_param("model.decoder.embed_tokens.embedding", + load_array("decoder.embed_tokens.embedding")) + load_param("model.decoder.embed_positions.embedding", + load_array("decoder.embed_positions.embedding"), + is_position_embedding=True) + + # if config.version > 2: + load_param("model.decoder.final_layer_norm.scale", + load_array("decoder.final_layer_norm.scale")) + load_param("model.decoder.final_layer_norm.bias", + load_array("decoder.final_layer_norm.bias")) + + layers_per_stage = config.num_hidden_layers // config.pp + + for i in range(config.num_hidden_layers): + stage_id = i // layers_per_stage + if stage_id != self.mesh_id: + continue + + # print(f"===>Reading layer {i} for mesh {self.mesh_id}") + tic1 = time.time() + param_prefix = f"model.decoder.layers.{i}." + load_prefix = f"decoder.layers.{i}." 
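+        # This variant reads arrays saved under Flax-style ".kernel"/".scale" names and
+        # uses them as-is (no transpose); the q/k/v projections are loaded as separate
+        # tensors, with the fused-QKV packing left commented out below. The ".weight"-based
+        # loader further down transposes the arrays instead.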
+ # Attention weights + wq = load_array(load_prefix + "self_attn.q_proj.kernel") + wk = load_array(load_prefix + "self_attn.k_proj.kernel") + wv = load_array(load_prefix + "self_attn.v_proj.kernel") + # dim = wq.shape[-1] + # w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape( + # (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1)) + # load_param(param_prefix + "attention.self.qkv_combined.kernel", w_qkv) + load_param(param_prefix + "self_attn.q_proj.kernel", wq) + load_param(param_prefix + "self_attn.k_proj.kernel", wk) + load_param(param_prefix + "self_attn.v_proj.kernel", wv) + + + bq = load_array(load_prefix + "self_attn.q_proj.bias") + bk = load_array(load_prefix + "self_attn.k_proj.bias") + bv = load_array(load_prefix + "self_attn.v_proj.bias") + # b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape( + # (3, dim)).transpose([1, 0]).reshape((-1,)) + load_param(param_prefix + "self_attn.q_proj.bias", bq) + load_param(param_prefix + "self_attn.k_proj.bias", bk) + load_param(param_prefix + "self_attn.v_proj.bias", bv) + # load_param(param_prefix + "attention.self.qkv_combined.bias", b_qkv) + load_param( + param_prefix + "self_attn.out_proj.kernel", + load_array(load_prefix + "self_attn.out_proj.kernel")) + load_param(param_prefix + "self_attn.out_proj.bias", + load_array(load_prefix + "self_attn.out_proj.bias")) + load_param(param_prefix + "self_attn_layer_norm.scale", + load_array(load_prefix + "self_attn_layer_norm.scale")) + load_param(param_prefix + "self_attn_layer_norm.bias", + load_array(load_prefix + "self_attn_layer_norm.bias")) + # FFN weights + load_param(param_prefix + "fc1.bias", + load_array(load_prefix + "fc1.bias")) + load_param(param_prefix + "fc1.kernel", + load_array(load_prefix + "fc1.kernel")) + load_param(param_prefix + "fc2.bias", + load_array(load_prefix + "fc2.bias")) + load_param(param_prefix + "fc2.kernel", + load_array(load_prefix + "fc2.kernel")) + load_param(param_prefix + "final_layer_norm.scale", + load_array(load_prefix + "final_layer_norm.scale")) + load_param(param_prefix + "final_layer_norm.bias", + load_array(load_prefix + "final_layer_norm.bias")) + # print(f"===>Reading layer {i} for mesh {self.mesh_id} done: {time.time() - tic1}") + print(f" === Reading on {self.mesh_id} done: {time.time() - tic} ===") + + +def load_opt_params_worker_func(self, path, prefix_to_idx, config, shapes, + uuids, indices, mesh_ids): + """The worker function to load OPT parameters.""" + + def load_array(key): + return np.load(os.path.join(path, key)) + + def load_param(param_key, loaded_array, is_position_embedding=False): + i = prefix_to_idx[param_key] + + for j in range(len(mesh_ids[i])): + if self.mesh_id != mesh_ids[i][j]: + continue + + if not is_position_embedding: + assert shapes[i][j] == loaded_array.shape, ( + f"{shapes[i][j]} vs. 
{loaded_array.shape}") + else: + if shapes[i][j] != loaded_array.shape: + assert shapes[i][j][1] == loaded_array.shape[1] + loaded_array = loaded_array[:shapes[i][j][0], :] + uuid = uuids[i][j] + datas = [] + for k in range(len(self.local_devices)): + idx = self.host_id * len(self.local_devices) + k + datas.append(loaded_array[indices[i][j][idx]]) + self.put_buffers(uuid, datas) + + # print(f" === Reading on {self.mesh_id} ===") + tic = time.time() + load_param("model.decoder.embed_tokens.embedding", + load_array("decoder.embed_tokens.weight")) + load_param("model.decoder.embed_positions.embedding", + load_array("decoder.embed_positions.weight"), + is_position_embedding=True) + + # if config.version > 2: + load_param("model.decoder.final_layer_norm.scale", + load_array("decoder.layer_norm.weight")) + load_param("model.decoder.final_layer_norm.bias", + load_array("decoder.layer_norm.bias")) + + layers_per_stage = config.num_hidden_layers // config.pp + + for i in range(config.num_hidden_layers): + stage_id = i // layers_per_stage + if stage_id != self.mesh_id: + continue + + # print(f"===>Reading layer {i} for mesh {self.mesh_id}") + tic1 = time.time() + param_prefix = f"model.decoder.layers.{i}." + load_prefix = f"decoder.layers.{i}." + # Attention weights + wq = load_array(load_prefix + "self_attn.q_proj.weight") + wk = load_array(load_prefix + "self_attn.k_proj.weight") + wv = load_array(load_prefix + "self_attn.v_proj.weight") + # dim = wq.shape[-1] + # w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape( + # (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1)) + # load_param(param_prefix + "attention.self.qkv_combined.kernel", w_qkv) + load_param(param_prefix + "self_attn.q_proj.kernel", wq.T) + load_param(param_prefix + "self_attn.k_proj.kernel", wk.T) + load_param(param_prefix + "self_attn.v_proj.kernel", wv.T) + + + bq = load_array(load_prefix + "self_attn.q_proj.bias") + bk = load_array(load_prefix + "self_attn.k_proj.bias") + bv = load_array(load_prefix + "self_attn.v_proj.bias") + # b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape( + # (3, dim)).transpose([1, 0]).reshape((-1,)) + load_param(param_prefix + "self_attn.q_proj.bias", bq) + load_param(param_prefix + "self_attn.k_proj.bias", bk) + load_param(param_prefix + "self_attn.v_proj.bias", bv) + # load_param(param_prefix + "attention.self.qkv_combined.bias", b_qkv) + load_param( + param_prefix + "self_attn.out_proj.kernel", + np.transpose(load_array(load_prefix + "self_attn.out_proj.weight"))) + load_param(param_prefix + "self_attn.out_proj.bias", + load_array(load_prefix + "self_attn.out_proj.bias")) + load_param(param_prefix + "self_attn_layer_norm.scale", + load_array(load_prefix + "self_attn_layer_norm.weight")) + load_param(param_prefix + "self_attn_layer_norm.bias", + load_array(load_prefix + "self_attn_layer_norm.bias")) + # FFN weights + load_param(param_prefix + "fc1.bias", + load_array(load_prefix + "fc1.bias")) + load_param(param_prefix + "fc1.kernel", + np.transpose(load_array(load_prefix + "fc1.weight"))) + load_param(param_prefix + "fc2.bias", + load_array(load_prefix + "fc2.bias")) + load_param(param_prefix + "fc2.kernel", + np.transpose(load_array(load_prefix + "fc2.weight"))) + load_param(param_prefix + "final_layer_norm.scale", + load_array(load_prefix + "final_layer_norm.weight")) + load_param(param_prefix + "final_layer_norm.bias", + load_array(load_prefix + "final_layer_norm.bias")) + # print(f"===>Reading layer {i} for mesh {self.mesh_id} done: {time.time() - tic1}") + print(f" === Reading on 
{self.mesh_id} done: {time.time() - tic} ===") + +setattr(MeshHostWorker, "load_opt_params_worker_func", + load_opt_params_worker_func) +# setattr(MeshHostWorker, "load_opt_params_worker_func", +# load_opt_params_worker_func_66b) + + +def load_params_dis_array(path, executable, params_aval, config, dummy=False): + """Load parameters with distributed arrays.""" + if dummy: + alpa.global_config.use_dummy_value_for_benchmarking = True + params_info, _ = executable.get_input_placement_specs() + flat_args, in_tree = tree_flatten(params_aval) + flat_info = tree_leaves(params_info) + if hasattr(executable, "mesh_group"): + ret = executable.mesh_group.shard_args_to_arrays( + flat_info, flat_args) + else: + ret = executable.physical_mesh.shard_args_to_arrays_ps( + flat_info, flat_args) + alpa.global_config.use_dummy_value_for_benchmarking = False + return ret + + params_info, _ = executable.get_input_placement_specs() + params_info = params_info.params + + prefix_to_flat_idx = {} + ct = itertools.count() + + def dfs(dict_tree, result_dict, cur_prefix): + if isinstance(dict_tree, (dict, flax.core.FrozenDict)): + for key in dict_tree.keys(): + dfs(dict_tree[key], result_dict, + cur_prefix + ("." if cur_prefix else "") + key) + else: + result_dict[cur_prefix] = next(ct) + + dfs(params_aval, prefix_to_flat_idx, "") + + flat_infos, in_tree = tree_flatten(params_info) + + flat_shapes = [] + flat_uuids = [] + flat_indices = [] + flat_mesh_ids = [] + flat_arrays = [] + + mesh_group = executable.mesh_group + + for info in flat_infos: + aval = info.aval + if len(info.mesh_ids) == 1: + mesh, spec = mesh_group[info.mesh_ids[0]], info.sharding_specs[0] + indices = pxla.spec_to_indices(aval.shape, spec) + ary_refs, ary_uuid = create_remote_array_refs(mesh) + flat_shapes.append([aval.shape]) + flat_uuids.append([ary_uuid[0]]) + flat_indices.append([indices]) + flat_mesh_ids.append([mesh.mesh_id]) + flat_arrays.append( + DistributedArray(mesh, aval, spec, ary_refs[0], indices)) + else: + tmp_shapes = [] + tmp_uuids = [] + tmp_indices = [] + tmp_mesh_ids = [] + tmp_arrays = [] + tmp_meshes = [] + for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs): + mesh = mesh_group[mesh_id] + indices = pxla.spec_to_indices(aval.shape, spec) + ary_refs, ary_uuid = create_remote_array_refs(mesh) + array = DistributedArray(mesh, aval, spec, ary_refs[0], indices) + tmp_shapes.append(aval.shape) + tmp_uuids.append(ary_uuid[0]) + tmp_indices.append(indices) + tmp_mesh_ids.append(mesh.mesh_id) + tmp_meshes.append(mesh) + tmp_arrays.append(array) + flat_shapes.append(tuple(tmp_shapes)) + flat_uuids.append(tuple(tmp_uuids)) + flat_indices.append(tuple(tmp_indices)) + flat_mesh_ids.append(tuple(tmp_mesh_ids)) + flat_arrays.append( + ReplicatedDistributedArray(tmp_meshes, tmp_arrays)) + + for m in executable.mesh_group.meshes: + for w in m.workers: + w.load_opt_params_worker_func.remote(path, prefix_to_flat_idx, + config, flat_shapes, + flat_uuids, flat_indices, + flat_mesh_ids) + + return tree_unflatten(in_tree, flat_arrays) + # return flat_arrays + + +def load_multi_executable_params_dis_array(path, + executables, + params_aval, + config, + dummy=False): + """Load parameters to workers that will be used by all executables. Accordingly, + we need to make sure the parameter sharding specs are identical for all executables. 
+ """ + shared_input_shard_specs = None + # for executable in executables.values(): + # # stage_input_shard_specs = executable.stage_input_shard_specs + # stage_input_shard_specs = executable.get_input_placement_specs + # if shared_input_shard_specs is not None: + # assert shared_input_shard_specs == stage_input_shard_specs, \ + # "All executables must have the same input sharding specs." + # else: + # shared_input_shard_specs = stage_input_shard_specs + return load_params_dis_array(path, + list(executables.values())[0], params_aval, + config, dummy) diff --git a/examples/opt_finetune/requirements.txt b/examples/opt_finetune/requirements.txt new file mode 100644 index 000000000..3b8901a82 --- /dev/null +++ b/examples/opt_finetune/requirements.txt @@ -0,0 +1,3 @@ +datasets >= 1.1.3 +transformers +tensorflow-gpu diff --git a/examples/opt_finetune/run_125m_pipe.sh b/examples/opt_finetune/run_125m_pipe.sh new file mode 100644 index 000000000..f34cea180 --- /dev/null +++ b/examples/opt_finetune/run_125m_pipe.sh @@ -0,0 +1,23 @@ +python3 run_clm_flax.py \ + --output_dir="./output" \ + --config_name="./config_125m.json" \ + --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --use_manual_layer \ + --dataset_name="wikitext" \ + --dataset_config_name="wikitext-2-raw-v1" \ + --do_train \ + --block_size="1024" \ + --per_device_train_batch_size="64" \ + --per_device_eval_batch_size="20" \ + --num_micro_batches 4 \ + --operator_parallel 1 \ + --pipeline_parallel 4 \ + --dtype="float16" \ + --learning_rate="5e-4" --warmup_steps="2000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ + --overwrite_output_dir \ + --num_train_epochs="8" \ + --logging_steps="1" \ + --save_steps="888" \ + --eval_steps="888" diff --git a/examples/opt_finetune/run_125m_shard.sh b/examples/opt_finetune/run_125m_shard.sh index 22b32f64a..1d923dedb 100644 --- a/examples/opt_finetune/run_125m_shard.sh +++ b/examples/opt_finetune/run_125m_shard.sh @@ -8,7 +8,7 @@ python3 run_clm_flax.py \ --per_device_train_batch_size="20" \ --per_device_eval_batch_size="20" \ --num_micro_batches 4 \ - --operator_parallel 4 \ + --operator_parallel 2 \ --pipeline_parallel 1 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ diff --git a/examples/opt_finetune/run_175b_pipe.sh b/examples/opt_finetune/run_175b_pipe.sh new file mode 100644 index 000000000..40d7f0554 --- /dev/null +++ b/examples/opt_finetune/run_175b_pipe.sh @@ -0,0 +1,24 @@ +python3 run_clm_flax.py \ + --output_dir="./output" \ + --config_name="./config_175b.json" \ + --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --use_manual_layer \ + --use_dummy_value \ + --dataset_name="wikitext" \ + --dataset_config_name="wikitext-2-raw-v1" \ + --do_train \ + --block_size="2048" \ + --per_device_train_batch_size="1536" \ + --per_device_eval_batch_size="64" \ + --num_micro_batches 768 \ + --operator_parallel 4 \ + --pipeline_parallel 16 \ + --dtype="float16" \ + --learning_rate="5e-4" --warmup_steps="2000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ + --overwrite_output_dir \ + --num_train_epochs="10" \ + --logging_steps="1" \ + --save_steps="888" \ + --eval_steps="888" diff --git a/examples/opt_finetune/run_2.7b_pipe.sh b/examples/opt_finetune/run_2.7b_pipe.sh index 4fe9ae47a..31b811d98 100644 --- a/examples/opt_finetune/run_2.7b_pipe.sh +++ b/examples/opt_finetune/run_2.7b_pipe.sh @@ -1,20 +1,23 @@ python3 run_clm_flax.py \ --output_dir="./output" \ - --model_name_or_path="facebook/opt-2.7b" \ + 
--config_name="./config_2.7b.json" \ + --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --use_manual_layer \ --dataset_name="wikitext" \ --dataset_config_name="wikitext-2-raw-v1" \ - --do_train --do_eval \ + --do_train \ --block_size="1024" \ - --per_device_train_batch_size="64" \ + --per_device_train_batch_size="128" \ --per_device_eval_batch_size="64" \ - --num_micro_batches 64 \ + --num_micro_batches 16 \ --operator_parallel 1 \ - --pipeline_parallel 2 \ + --pipeline_parallel 4 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ --overwrite_output_dir \ --num_train_epochs="10" \ - --logging_steps="5" \ - --save_steps="40" \ - --eval_steps="25" + --logging_steps="1" \ + --save_steps="888" \ + --eval_steps="888" diff --git a/examples/opt_finetune/run_2.7b_shard.sh b/examples/opt_finetune/run_2.7b_shard.sh index 33f46e411..1749e82ea 100644 --- a/examples/opt_finetune/run_2.7b_shard.sh +++ b/examples/opt_finetune/run_2.7b_shard.sh @@ -8,7 +8,7 @@ python3 run_clm_flax.py \ --per_device_train_batch_size="20" \ --per_device_eval_batch_size="20" \ --num_micro_batches 4 \ - --operator_parallel 4 \ + --operator_parallel 2 \ --pipeline_parallel 1 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ diff --git a/examples/opt_finetune/run_30b_pipe.sh b/examples/opt_finetune/run_30b_pipe.sh new file mode 100644 index 000000000..cb389f083 --- /dev/null +++ b/examples/opt_finetune/run_30b_pipe.sh @@ -0,0 +1,23 @@ +python3 run_clm_flax.py \ + --output_dir="./output" \ + --config_name="./config_30b.json" \ + --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --use_manual_layer \ + --dataset_name="wikitext" \ + --dataset_config_name="wikitext-2-raw-v1" \ + --do_train \ + --block_size="1024" \ + --per_device_train_batch_size="1024" \ + --per_device_eval_batch_size="64" \ + --num_micro_batches 256 \ + --operator_parallel 1 \ + --pipeline_parallel 16 \ + --dtype="float16" \ + --learning_rate="5e-4" --warmup_steps="2000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ + --overwrite_output_dir \ + --num_train_epochs="10" \ + --logging_steps="1" \ + --save_steps="888" \ + --eval_steps="888" diff --git a/examples/opt_finetune/run_66b_pipe.sh b/examples/opt_finetune/run_66b_pipe.sh new file mode 100644 index 000000000..2990026c4 --- /dev/null +++ b/examples/opt_finetune/run_66b_pipe.sh @@ -0,0 +1,24 @@ +python3 run_clm_flax.py \ + --output_dir="./output" \ + --config_name="./config_66b.json" \ + --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --use_manual_layer \ + --use_dummy_value \ + --dataset_name="wikitext" \ + --dataset_config_name="wikitext-2-raw-v1" \ + --do_train \ + --block_size="1024" \ + --per_device_train_batch_size="1024" \ + --per_device_eval_batch_size="64" \ + --num_micro_batches 512 \ + --operator_parallel 4 \ + --pipeline_parallel 8 \ + --dtype="float16" \ + --learning_rate="5e-4" --warmup_steps="2000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ + --overwrite_output_dir \ + --num_train_epochs="10" \ + --logging_steps="1" \ + --save_steps="888" \ + --eval_steps="888" diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index d72961aea..b1fb26f08 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -32,7 +32,8 @@ import functools from itertools import chain from pathlib import Path -from typing import Callable, Optional +from typing import Callable, Optional, 
Tuple, Dict +import copy import datasets import numpy as np @@ -41,7 +42,8 @@ import alpa from alpa.model.model_util import DynamicScale, TrainState -from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption +from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption, mark_pipeline_boundary, CreateStateParallel +from alpa import global_config import jax import jax.numpy as jnp import optax @@ -50,6 +52,7 @@ from flax import jax_utils, traverse_util from flax.training import train_state from flax.training.common_utils import onehot, shard, shard_prng_key +from flax.core.frozen_dict import unfreeze, FrozenDict from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, @@ -62,6 +65,8 @@ set_seed, ) +from examples.opt_finetune.load_params import load_params_dis_array + alpa.init(cluster="ray") from transformers.testing_utils import CaptureLogger @@ -91,6 +96,9 @@ class TrainingArguments: ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + use_dummy_value: bool = field(default=False, metadata={"help": "Whether to use dummy values for debugging."}) + use_manual_layer: bool = field(default=False, metadata={"help": "Whether to use manual layer annotation."}) + alpa_init: bool = field(default=False, metadata={"help": "Whether to use Alpa's distributed gpu init."}) per_device_train_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} ) @@ -320,7 +328,7 @@ def create_learning_rate_fn( return schedule_fn -def monkey_patch_remat(): +def monkey_patch_remat(use_manual_layer=False): # Use monkey patch to add remat for all transformer layers. from transformers.models.opt.modeling_flax_opt import FlaxOPTDecoderLayer, FlaxOPTDecoderLayerCollection from flax.linen.partitioning import remat @@ -349,7 +357,7 @@ def call( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -365,6 +373,10 @@ def call( if output_attentions: all_self_attns += (layer_outputs[1],) + # add manual layer annotations + if use_manual_layer and i != len(self.layers) - 1: + mark_pipeline_boundary() + outputs = [hidden_states, all_hidden_states, all_self_attns] return outputs @@ -512,6 +524,11 @@ def main(): use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: + if "175b" in model_args.model_name_or_path.lower(): + raise RuntimeError("HuggingFace hub does not have OPT-175B model. If you want to finetune OPT-175B, " + "please pass --config_name instead and set it as the path to the config file " + "provided at examples/opt_finetune/config_175b.json. Please obtain the weights" + "by contacting Meta Inc.") config = AutoConfig.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, @@ -522,7 +539,7 @@ def main(): logger.warning("You are instantiating a new config instance from scratch.") if training_args.use_remat: - monkey_patch_remat() + monkey_patch_remat(training_args.use_manual_layer) if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( @@ -545,6 +562,7 @@ def main(): "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
        )

+    do_init = not training_args.alpa_init
     if model_args.model_name_or_path:
         model = FlaxAutoModelForCausalLM.from_pretrained(
             model_args.model_name_or_path,
@@ -552,6 +570,7 @@
             seed=training_args.seed,
             dtype=getattr(jnp, model_args.dtype),
             use_auth_token=True if model_args.use_auth_token else None,
+            _do_init=do_init
         )
         #from transformers import FlaxOPTForCausalLM
         #config.num_hidden_layers = 2
@@ -565,8 +584,33 @@
             config,
             seed=training_args.seed,
             dtype=getattr(jnp, model_args.dtype),
+            _do_init=do_init
         )

+    # from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves
+    #
+    # file_path = "/tmp/66B_np"
+    # os.makedirs(file_path, exist_ok=True)
+    #
+    # def save_to_disk(params, prefix=""):
+    #     if isinstance(params, dict):
+    #         for key in params:
+    #             if key == "model":
+    #                 new_prefix = prefix
+    #             elif len(prefix) > 0:
+    #                 new_prefix = prefix + "." + key
+    #             else:
+    #                 new_prefix = key
+    #             save_to_disk(params[key], new_prefix)
+    #     if isinstance(params, jax.numpy.DeviceArray):
+    #         print(f"Save the tensor {prefix} to {file_path}")
+    #         with open(os.path.join(file_path, prefix), "wb") as g:
+    #             jnp.save(g, params)
+    #         return
+    #
+    # save_to_disk(model.params)
+    # exit(1)
+
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
@@ -653,14 +697,17 @@ def group_texts(examples):
         if data_args.max_train_samples is not None:
             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
             train_dataset = train_dataset.select(range(max_train_samples))
-
-    if training_args.do_eval:
-        if "validation" not in tokenized_datasets:
-            raise ValueError("--do_eval requires a validation dataset")
-        eval_dataset = lm_datasets["validation"]
-        if data_args.max_eval_samples is not None:
-            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
-            eval_dataset = eval_dataset.select(range(max_eval_samples))
+        new_datasets = []
+        for i in range(50):
+            new_datasets.append(copy.deepcopy(train_dataset))
+        train_dataset = datasets.concatenate_datasets(new_datasets)
+
+    if "validation" not in tokenized_datasets:
+        raise ValueError("--do_eval requires a validation dataset")
+    eval_dataset = lm_datasets["validation"]
+    if data_args.max_eval_samples is not None:
+        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+        eval_dataset = eval_dataset.select(range(max_eval_samples))

     # Adjust batch size and num_micro_batches for small datasets
     num_devices = alpa.get_global_num_devices()
@@ -698,8 +745,19 @@

     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
-    train_batch_size = int(training_args.per_device_train_batch_size) * num_devices
-    eval_batch_size = int(training_args.per_device_eval_batch_size) * num_devices
+    # Infer the data-parallel degree
+    op = training_args.operator_parallel
+    pp = training_args.pipeline_parallel
+    dp = num_devices // op // pp
+    assert dp >= 1, "Your settings of `operator_parallel` or `pipeline_parallel` are problematic. " \
+        "Please make sure num_devices is divisible by op * pp."
+    # Copy the parallel config into configs
+    setattr(config, "op", op)
+    setattr(config, "pp", pp)
+    setattr(config, "dp", dp)
+
+    train_batch_size = int(training_args.per_device_train_batch_size) * dp
+    eval_batch_size = int(training_args.per_device_eval_batch_size) * dp
     steps_per_epoch = len(train_dataset) // train_batch_size
     total_train_steps = steps_per_epoch * num_epochs

@@ -754,8 +812,9 @@ def decay_mask_fn(params):
         alpa.global_config.flax_always_use_fp16_embedding = True
     else:
         use_master_copy = dynamic_scale = None
-    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer,
-                              dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)
+
+    if training_args.use_dummy_value:
+        alpa.global_config.use_dummy_value_for_benchmarking = True

     def loss_fn(logits, labels):
         shift_logits = logits[..., :-1, :]
@@ -806,18 +865,73 @@ def eval_step(params, batch):
         labels = batch.pop("labels")
         logits = model(**batch, params=params, deterministic=True)[0]
         loss = loss_fn(logits, labels)
-
         # summarize metrics
         metrics = {"loss": loss}
         return metrics

-    # Create parallel version of the train and eval step
-    method = alpa.get_3d_parallel_method(
-        num_micro_batches=training_args.num_micro_batches,
-        data_parallel=-1,
-        operator_parallel=training_args.operator_parallel,
-        pipeline_parallel=training_args.pipeline_parallel)
+    use_create_state_parallel = True
+
+    if not use_create_state_parallel and training_args.alpa_init:
+        # In this case, params are not initialized yet, and we use shaped arrays to trigger alpa compilation
+        rngkey = jax.core.ShapedArray((2,), jnp.uint32)
+        input_ids = jax.core.ShapedArray((1, 128), jnp.int32)
+        attention_mask = jax.core.ShapedArray((1, 128), jnp.int32)
+        position_ids = jax.core.ShapedArray((1, 128), jnp.int32)
+        params_aval = unfreeze(jax.eval_shape(model.module.init, rngkey, input_ids, attention_mask, position_ids)["params"])
+        state_aval = TrainState.create_aval(apply_fn=model.__call__, params=params_aval, tx=optimizer,
+                                            dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)
+    elif use_create_state_parallel:
+        def create_state():
+            model = FlaxAutoModelForCausalLM.from_config(
+                config,
+                seed=training_args.seed,
+                dtype=getattr(jnp, model_args.dtype),
+                _do_init=do_init
+            )
+            optimizer = optax.chain(
+                optax.clip_by_global_norm(1.0),
+                optax.adamw(
+                    learning_rate=linear_decay_lr_schedule_fn,
+                    b1=training_args.adam_beta1,
+                    b2=training_args.adam_beta2,
+                    eps=training_args.adam_epsilon,
+                    weight_decay=training_args.weight_decay,
+                    mask=decay_mask_fn)
+            )
+            rngkey = jnp.ones((2,), jnp.uint32)
+            input_ids = jnp.ones((1, 128), jnp.int32)
+            attention_mask = jnp.ones((1, 128), jnp.int32)
+            position_ids = jnp.ones((1, 128), jnp.int32)
+            # params_aval = unfreeze(
+            #     jax.eval_shape(model.module.init, rngkey, input_ids, attention_mask, position_ids)["params"])
+            params = unfreeze(model.module.init(rngkey, input_ids, attention_mask, position_ids)["params"])
+            state_aval = TrainState.create(apply_fn=model.__call__, params=params, tx=optimizer,
+                                           dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)
+            return state_aval
+    else:
+        # In this case, params have been initialized by HF in the CPU memory of the driver node.
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + # Create parallel version of the train and eval step + if training_args.use_manual_layer: + assert config.num_hidden_layers % pp == 0, f"There are {config.num_hidden_layers} layers but {pp} stages." + n_layers_per_stage = config.num_hidden_layers // pp + forward_stage_layer_ids = [list(range(i * n_layers_per_stage, (i + 1) * n_layers_per_stage)) for i in range(pp)] + print(f"n_layers_per_stage {n_layers_per_stage}, forward_stage_layer_ids {forward_stage_layer_ids}") + method = alpa.get_3d_parallel_method( + num_micro_batches=training_args.num_micro_batches, + data_parallel=-1, + operator_parallel=training_args.operator_parallel, + pipeline_parallel=training_args.pipeline_parallel, + use_manual_layer_option=True, + forward_stage_layer_ids=forward_stage_layer_ids) + else: + method = alpa.get_3d_parallel_method( + num_micro_batches=training_args.num_micro_batches, + data_parallel=-1, + operator_parallel=training_args.operator_parallel, + pipeline_parallel=training_args.pipeline_parallel) p_train_step = alpa.parallelize(train_step, method=method, donate_argnums=(0,)) @@ -825,6 +939,83 @@ def eval_step(params, batch): method=alpa.FollowParallel( p_train_step, num_micro_batches=eval_num_micro_batches)) + if training_args.alpa_init: + print(f" - Compile executables. ", end="", flush=True) + tic = time.time() + # trigger compile + seq_len = data_args.block_size + train_input_shape = (train_batch_size, seq_len) + batch_aval = {"input_ids": + jax.core.ShapedArray( + train_input_shape, jnp.int32), + "position_ids": + jax.core.ShapedArray( + train_input_shape, jnp.int32), + "attention_mask": + jax.core.ShapedArray( + train_input_shape, jnp.int32), + "labels": + jax.core.ShapedArray( + train_input_shape, jnp.int32), + } + + if use_create_state_parallel: + # batch_aval = { + # "input_ids": jnp.ones(train_input_shape, jnp.int32), + # "position_ids": jnp.ones(train_input_shape, jnp.int32), + # "attention_mask": jnp.ones(train_input_shape, jnp.int32), + # "labels": jnp.ones(train_input_shape, jnp.int32), + # } + p_create_state = alpa.parallelize(create_state, method=CreateStateParallel(p_train_step, batch_aval)) + state_aval = p_create_state() + + train_executable = p_train_step.get_executable(state_aval, batch_aval) + eval_executable = p_eval_step.get_executable(state_aval.params, batch_aval) + train_executable.sync() + model._is_initialized = True + print(f" Compilation takes {time.time() - tic:.2f} seconds.") + + if not training_args.use_dummy_value: + print(" - Load parameters. 
", end="", flush=True) + tic = time.time() + assert config.weight_path and os.path.exists(config.weight_path), f"Cannot find weight at {config.weight_path}" + params = load_params_dis_array(config.weight_path, train_executable, state_aval.params, config, dummy=False) + train_executable.sync() + print(f" Load parameters takes {time.time() - tic:.2f} seconds.") + # from alpa.serialization import restore_checkpoint + # params = restore_checkpoint(path, ) + # Work around the model._initialized in HF + model._is_initialized = True + model.params = params + tic = time.time() + # state = TrainState.create_from(train_state=state_aval, + # params=params, + # dynamic_scale=dynamic_scale, + # use_master_copy=True) + + state = TrainState.create(apply_fn=model.__call__, params=params, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + + def compare_dict(state1, state2): + print("Compare params: ") + params1 = jax.tree_util.tree_flatten(state1) + params2 = jax.tree_util.tree_flatten(state2) + for p1, p2 in zip(params1[0], params2[0]): + assert type(p1) == type(p2) + if isinstance(p1, alpa.device_mesh.DistributedArray): + print(f"Sharding spec 1: {p1.sharding_spec}, sharding spec 2 {p2.sharding_spec}") + assert p1.sharding_spec == p2.sharding_spec, f"wrong at {p1} and {p2}" + else: + assert isinstance(p1, alpa.device_mesh.ReplicatedDistributedArray) + print(f"Sharding spec 1: {p1.replica.sharding_spec}, sharding spec 2 {p2.replica.sharding_spec}") + assert p1.replica.sharding_spec == p2.replica.sharding_spec, f"wrong at {p1} and {p2}" + + + + # state = TrainState.create_distributed(apply_fn=model.__call__, params=model.params, tx=optimizer, + # dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + print(f" Create train states takes {time.time() - tic:.2f} seconds.") + dump_debug_info_train_step = dump_debug_info_eval_step = True logger.info("***** Running training *****") @@ -853,13 +1044,28 @@ def eval_step(params, batch): # Generate an epoch by shuffling sampling indices from the train dataset train_loader = data_loader(input_rng, train_dataset, train_batch_size, train_min_batch_size, shuffle=True) - steps_per_epoch = len(train_dataset) // train_batch_size + + if training_args.use_dummy_value: + steps_per_epoch = 50 + else: + steps_per_epoch = len(train_dataset) // train_batch_size # train for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): - batch = next(train_loader) - batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * - batch["attention_mask"]) - 1 - state, train_metric = p_train_step(state, batch) + if not training_args.use_dummy_value: + batch = next(train_loader) + batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * + batch["attention_mask"]) - 1 + else: + batch = batch_aval + # batch = next(train_loader) + # batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * + # batch["attention_mask"]) - 1 + # for b in batch: + # batch[b] = batch[b].astype(jnp.int32) + if training_args.use_dummy_value: + state_aval, train_metric = p_train_step(state_aval, batch) + else: + state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) cur_step = epoch * (len(train_dataset) // train_batch_size) + step @@ -868,7 +1074,7 @@ def eval_step(params, batch): dump_debug_info_train_step = False executable = p_train_step.get_last_executable() executable.sync() - executable.dump_debug_info("alpa_debug_info") + executable.dump_debug_info("/tmp/alpa_debug_info") epochs.write(f"Initial 
compilation completed. " f"Time elapsed: {time.time() - train_start:.2f} s") @@ -897,9 +1103,10 @@ def eval_step(params, batch): epochs.write( f"Step... {cur_step} | " f"Loss: {train_metric['loss'].mean():.4f}, " - f"Learning Rate: {train_metric['learning_rate'].mean():.5f}, " + f"Learning Rate: {train_metric['learning_rate'].mean():.8f}, " f"Throughput: {throughput_tokens:.2f} token/s, " - f"{throughput_tflops:.2f} TFLOP/s" + f"{throughput_tflops:.2f} TFLOP/s, " + f"{throughput_tflops * 4 / 3:.2f} TFLOP/s." ) train_metrics = [] @@ -983,11 +1190,12 @@ def eval_step(params, batch): json.dump(eval_metrics, f, indent=4, sort_keys=True) # Save the final model - epochs.write("\nSave the final model...") - alpa.prefetch(state.params) - params = alpa.util.map_to_nparray(state.params) - model.save_pretrained(training_args.output_dir, params=params) - tokenizer.save_pretrained(training_args.output_dir) + if not training_args.use_dummy_value: + epochs.write("\nSave the final model...") + alpa.prefetch(state.params) + params = alpa.util.map_to_nparray(state.params) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) if __name__ == "__main__": diff --git a/examples/opt_finetune/test_save_load.py b/examples/opt_finetune/test_save_load.py new file mode 100644 index 000000000..4d7c82de5 --- /dev/null +++ b/examples/opt_finetune/test_save_load.py @@ -0,0 +1,65 @@ +import os +import tempfile + +import ray + +import alpa +from alpa import (init, shutdown, parallelize, DistributedArray, + PipeshardParallel, save_checkpoint, restore_checkpoint) +from alpa.device_mesh import get_global_cluster +from alpa.testing import get_bert_layer_train_state_and_step, assert_allclose +from alpa.parallel_method import get_3d_parallel_method + + +def _get_save_prefix(): + device_cluster = get_global_cluster() + if len(device_cluster.host_info) > 1: + raise RuntimeError("The multi-host test requires a mounted EFS! ") + else: + # Use tmp dir for the single-host test + save_prefix = "/tmp/" + return save_prefix + + +alpa.init() +ckpt_dir = "/mnt/alpa-opt/alpa/examples/opt_finetune/test_ckpt" +state, batch, train_step = get_bert_layer_train_state_and_step( + batch_size=16, + seq_len=8, + num_layers=2, + hidden_size=128, + num_heads=8, + clip_by_global_norm=False, + use_dynamic_scale=False, + add_manual_pipeline_marker=True) + +method = PipeshardParallel(num_micro_batches=2, layer_option="manual") + +serial_train_step = train_step +parallel_train_step = parallelize(train_step, method=method) +executable = parallel_train_step.get_executable(state, batch) + +serial_state = state +parallel_state = state +serial_state = serial_train_step(serial_state, batch)[0] +parallel_state = parallel_train_step(parallel_state, batch)[0] +# assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3) + +with tempfile.TemporaryDirectory(prefix="/tmp/") as cache_dir: + # Save checkpoint + save_checkpoint(ckpt_dir, parallel_state, 1, cache_dir) + + # Sync all the move workers + executable.sync_move_workers() + + # Restore checkpoint + state_ps, _ = executable.get_input_placement_specs() + load_state = restore_checkpoint(ckpt_dir, 1, state_ps) + + # Run after load + serial_state = serial_train_step(serial_state, batch)[0] + load_state = parallel_train_step(load_state, batch)[0] + + # Check results + assert_allclose(serial_state.params, load_state.params, 1e-3, + 1e-3)
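
The sketches below are reviewer notes, not part of the patch. First, the batch-size change in run_clm_flax.py: the global batch is now scaled by the inferred data-parallel degree `dp = num_devices // op // pp` rather than by `num_devices`. A minimal sketch of that computation under the same assumptions, with an explicit divisibility check (the helper name `infer_data_parallel` is illustrative, not from the patch):

```python
# Illustrative helper (not in the patch): infer the data-parallel degree the way
# run_clm_flax.py now does, with an explicit divisibility check.
def infer_data_parallel(num_devices: int, operator_parallel: int,
                        pipeline_parallel: int) -> int:
    assert num_devices % (operator_parallel * pipeline_parallel) == 0, (
        "num_devices must be divisible by operator_parallel * pipeline_parallel")
    dp = num_devices // operator_parallel // pipeline_parallel
    assert dp >= 1
    return dp


# Example: 8 devices with op=2, pp=2 gives dp=2, so the global train batch is
# per_device_train_batch_size * dp (no longer * num_devices).
dp = infer_data_parallel(8, 2, 2)
train_batch_size = 4 * dp  # per_device_train_batch_size = 4 -> global batch 8
```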
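The `_do_init=do_init` change defers Hugging Face weight initialization when `--alpa_init` is used; parameter shapes can then be obtained abstractly, which is what the `TrainState.create_aval` branch does with `jax.eval_shape`. A sketch of that idea, assuming a Flax causal-LM module whose `init` takes `(rng, input_ids, attention_mask, position_ids)` as in the patch (`abstract_params` is an illustrative helper name):

```python
# Sketch only: obtain parameter shapes without materializing any weights.
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import unfreeze


def abstract_params(model, seq_len: int = 128):
    """Trace model.module.init abstractly; returns a pytree of shape/dtype structs."""
    rngkey = jax.core.ShapedArray((2,), jnp.uint32)
    input_ids = jax.core.ShapedArray((1, seq_len), jnp.int32)
    attention_mask = jax.core.ShapedArray((1, seq_len), jnp.int32)
    position_ids = jax.core.ShapedArray((1, seq_len), jnp.int32)
    # eval_shape never allocates device memory for the parameters.
    return unfreeze(jax.eval_shape(model.module.init, rngkey, input_ids,
                                   attention_mask, position_ids)["params"])
```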
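When `--use_manual_layer` is set, the patch splits `config.num_hidden_layers` evenly across the `pipeline_parallel` stages and passes the result as `forward_stage_layer_ids`. A sketch of that assignment (helper name illustrative; it assumes divisibility, exactly as the patch's assert does):

```python
# Illustrative helper (not in the patch): even layer-to-stage assignment.
from typing import List


def make_forward_stage_layer_ids(num_hidden_layers: int, pp: int) -> List[List[int]]:
    assert num_hidden_layers % pp == 0, (
        f"There are {num_hidden_layers} layers but {pp} stages.")
    n_layers_per_stage = num_hidden_layers // pp
    return [list(range(i * n_layers_per_stage, (i + 1) * n_layers_per_stage))
            for i in range(pp)]


# 24 layers over 4 stages -> [[0..5], [6..11], [12..17], [18..23]]
print(make_forward_stage_layer_ids(24, 4))
```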
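Finally, the compile-trigger path builds `batch_aval` out of `jax.core.ShapedArray`s, and the same abstract batch is reused as the "dummy" batch when `--use_dummy_value` turns on `use_dummy_value_for_benchmarking`. A sketch of constructing such an abstract batch (helper name illustrative):

```python
# Illustrative helper (not in the patch): build an abstract int32 batch with the
# four fields the train step expects.
import jax
import jax.numpy as jnp


def make_abstract_batch(batch_size: int, seq_len: int):
    shape = (batch_size, seq_len)
    return {
        name: jax.core.ShapedArray(shape, jnp.int32)
        for name in ("input_ids", "position_ids", "attention_mask", "labels")
    }
```

Compiling with `get_executable` only needs these shapes; feeding the same dict to the real step is presumably meaningful only because `use_dummy_value_for_benchmarking` lets Alpa substitute dummy device buffers for the inputs.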