# WhisperX on a ROCm PyTorch base image, with CTranslate2 built from the CTranslate2-rocm fork.
# Base image. The default keeps the original moving tag; pass
#   --build-arg BASE_IMAGE=rocm/pytorch:<tag>
# to pin a specific release. Later steps in this file hard-code
# version-specific paths (a conda env name and a /opt/rocm-<version> directory),
# so a pinned tag is strongly recommended for reproducible builds.
ARG BASE_IMAGE=rocm/pytorch:latest
FROM ${BASE_IMAGE}

# Build-time arguments (no defaults; supply them with --build-arg).

# GPU architecture list, forwarded to CMake as CMAKE_HIP_ARCHITECTURES.
ARG PYTORCH_ROCM_ARCH

# Declared but not referenced by any instruction in this file; as an ARG it is
# only visible to RUN steps during the build, not in the running container.
# NOTE(review): if this is needed at runtime it must be an ENV or be passed
# to `docker run` — confirm.
ARG HSA_OVERRIDE_GFX_VERSION

# Replacement download URL written into whisperx/vad.py by a later patch step.
ARG MODEL_URL

# Run all following RUN instructions with bash; later steps rely on the
# bash-only `source` builtin.
SHELL ["/bin/bash", "-c"]

# Register conda's shell hook in ~/.bashrc and append an activation line for
# the py_3.9 environment, so steps that `source ~/.bashrc` pick it up.
RUN conda init bash \
 && echo "conda activate py_3.9" >> ~/.bashrc

# System packages. git is required by the CTranslate2-rocm clone step;
# ffmpeg, libomp5 and wget are not referenced by any other instruction in this
# file (presumably runtime dependencies — confirm).
# --no-install-recommends keeps optional packages out of the image;
# ca-certificates is therefore listed explicitly so HTTPS (git clone, wget)
# keeps working. DEBIAN_FRONTEND is set inline so it does not leak into the
# runtime environment. The apt lists are removed in the same layer.
RUN apt-get update \
 && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        ca-certificates \
        ffmpeg \
        git \
        libomp5 \
        wget \
 && rm -rf /var/lib/apt/lists/*

# Sources are checked out and built under /opt/whisperx; this also remains the
# working directory of the final image.
WORKDIR /opt/whisperx

# Clone the ROCm fork of CTranslate2 (with submodules), configure it with the
# HIP backend for the architectures given in PYTORCH_ROCM_ARCH, build with one
# job per CPU, and install into the /opt/conda/envs/py_3.9 prefix.
# ~/.bashrc is sourced first so the conda activation line appended to it is
# meant to apply to this shell. The compiler/HIP variables are set only for the
# configure command; hipconfig supplies the clang directory and the ROCm root.
RUN git clone https://github.com/arlo-phoenix/CTranslate2-rocm.git --recurse-submodules \
 && cd CTranslate2-rocm \
 && source ~/.bashrc \
 && CLANG_CMAKE_CXX_COMPILER=clang++ \
    CXX=clang++ \
    HIPCXX="$(hipconfig -l)/clang" \
    HIP_PATH="$(hipconfig -R)" \
    cmake -S . -B build \
        -DWITH_MKL=OFF \
        -DWITH_HIP=ON \
        -DCMAKE_HIP_ARCHITECTURES=$PYTORCH_ROCM_ARCH \
        -DBUILD_TESTS=ON \
        -DWITH_CUDNN=ON \
 && cmake --build build -- -j$(nproc) \
 && cmake --install build --prefix /opt/conda/envs/py_3.9 \
 && ldconfig

# Build and install the CTranslate2 Python bindings from the cloned source tree.
# CPLUS_INCLUDE_PATH / LIBRARY_PATH point the wheel build at the headers and
# libraries installed under the /opt/conda/envs/py_3.9 prefix; they are set only
# for the `setup.py bdist_wheel` command.
# --no-cache-dir stops pip from leaving its download cache in the image layer.
RUN source ~/.bashrc \
 && cd /opt/whisperx/CTranslate2-rocm/python \
 && pip install --no-cache-dir -r install_requirements.txt \
 && CPLUS_INCLUDE_PATH=/opt/conda/envs/py_3.9/include \
    LIBRARY_PATH=/opt/conda/envs/py_3.9/lib \
    python setup.py bdist_wheel \
 && pip install --no-cache-dir dist/*.whl

# Python packages:
#   1. force-reinstall torch/torchvision/torchaudio from the ROCm 6.1 wheel index;
#   2. install/upgrade the WhisperX dependencies (pyannote.audio and
#      faster-whisper are pinned);
#   3. install whisperx 3.1.1 itself with --no-deps so it does not pull in its
#      own dependency set.
# --no-cache-dir keeps pip's download cache (the torch wheels are large) out of
# the image layer.
# NOTE(review): the wheel index targets rocm6.1 while another step in this file
# references /opt/rocm-6.3.1 — confirm the intended ROCm version.
RUN source ~/.bashrc \
 && pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1 --force-reinstall \
 && pip install --no-cache-dir transformers pandas nltk pyannote.audio==3.1.1 faster-whisper==1.0.1 -U \
 && pip install --no-cache-dir whisperx==3.1.1 --no-deps

# Patch the installed whisperx package in place.
#
# The package directory is resolved from `pip show whisperx` (using the same
# sourced ~/.bashrc as the install steps) instead of a hard-coded
# /opt/conda/envs/py_3.10/... path, which did not match the py_3.9 environment
# name used elsewhere in this file and breaks whenever the base image changes
# its Python version. If the lookup fails, the sed calls fail on the missing
# file and abort the build.
#
#   1. asr.py: after each line containing `"suppress_numerals": False`, append
#      the keys max_new_tokens, clip_timestamps and
#      hallucination_silence_threshold (all None).
#   2. vad.py: replace the built-in VAD segmentation model URL with the
#      MODEL_URL build argument. The sed script is double-quoted so the shell
#      expands ${MODEL_URL}; in single quotes the literal text "$MODEL_URL"
#      would be written into the file. MODEL_URL must not contain `|` or `&`,
#      which are special in the sed replacement.
#   3. vad.py: delete the SHA-256 checksum comparison (the matching `if` line
#      and the three lines after it), since a replaced URL no longer encodes
#      the expected hash.
RUN source ~/.bashrc \
 && WHISPERX_DIR="$(pip show whisperx | sed -n 's/^Location: //p')/whisperx" \
 && sed -i '/"suppress_numerals": False/a \ "max_new_tokens": None,\n "clip_timestamps": None,\n "hallucination_silence_threshold": None,' \
        "${WHISPERX_DIR}/asr.py" \
 && sed -i "s|https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin|${MODEL_URL}|" \
        "${WHISPERX_DIR}/vad.py" \
 && sed -i '/if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split/,+3d' \
        "${WHISPERX_DIR}/vad.py"

# Append the lib directory of the py_3.9 install prefix to the runtime library
# search path.
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/opt/conda/envs/py_3.9/lib/"

# Expose ROCm's bundled libiomp5.so under /usr/lib and refresh the linker cache.
# NOTE(review): the source path is tied to ROCm 6.3.1; it must be updated if the
# base image ships a different ROCm version.
RUN ln -s /opt/rocm-6.3.1/lib/llvm/lib/libiomp5.so /usr/lib/libiomp5.so \
 && ldconfig

# Create the entry script: source ~/.bashrc, then sleep in a loop so the
# container stays alive.
# printf '%s\n' writes each argument on its own line. The previous
# `echo '...\n...'` relied on echo interpreting "\n", which bash's builtin echo
# does not do without -e; under SHELL ["/bin/bash", "-c"] that produced a
# one-line file with a broken shebang.
RUN printf '%s\n' \
        '#!/bin/bash' \
        'source ~/.bashrc' \
        'while true; do sleep 86400; done' \
        > /entrypoint.sh \
 && chmod +x /entrypoint.sh

ENTRYPOINT ["/entrypoint.sh"]