Skip to content

Commit 7d7928f

Browse files
committed
feat: cudnn
1 parent c1c9df6 commit 7d7928f

8 files changed

Lines changed: 294 additions & 76 deletions

File tree

buildSrc/src/main/kotlin/io/github/sgpublic/DockerPlugin.kt

Lines changed: 96 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,7 @@ package io.github.sgpublic
33
import com.bmuschko.gradle.docker.DockerRemoteApiPlugin
44
import com.bmuschko.gradle.docker.tasks.image.DockerBuildImage
55
import com.bmuschko.gradle.docker.tasks.image.DockerPushImage
6-
import io.github.sgpublic.tasks.CudaDockerfile
7-
import io.github.sgpublic.tasks.PlaywrightDockerfile
8-
import io.github.sgpublic.tasks.PoetryDockerfile
9-
import io.github.sgpublic.tasks.PythonVersions
10-
import io.github.sgpublic.tasks.CodaVersions
6+
import io.github.sgpublic.tasks.*
117
import io.github.sgpublic.utils.*
128
import org.gradle.api.Plugin
139
import org.gradle.api.Project
@@ -36,7 +32,7 @@ class DockerPlugin: Plugin<Project> {
3632
"buildPoetry${simplyVersion}${debianVer.name.capitalized()}Image",
3733
DockerBuildImage::class.java
3834
) {
39-
group = "python"
35+
group = "python${simplyVersion}"
4036
inputDir.set(buildDir("poetry"))
4137
buildArgs.putAll(mapOf(
4238
"PYTHON_VERSION" to pyFullVer,
@@ -56,7 +52,7 @@ class DockerPlugin: Plugin<Project> {
5652
"pushPoetry${simplyVersion}${debianVer.name.capitalized()}Image",
5753
DockerPushImage::class.java
5854
) {
59-
group = "python"
55+
group = "python${simplyVersion}"
6056
dependsOn(buildPoetry)
6157
images.addAll(tagsPoetry)
6258
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, tagsPoetry.last())
@@ -70,7 +66,7 @@ class DockerPlugin: Plugin<Project> {
7066
"buildPlaywright${simplyVersion}${debianVer.name.capitalized()}Image",
7167
DockerBuildImage::class.java
7268
) {
73-
group = "python"
69+
group = "python${simplyVersion}"
7470
mustRunAfter(buildPoetry)
7571

7672
buildArgs.putAll(mapOf(
@@ -92,7 +88,7 @@ class DockerPlugin: Plugin<Project> {
9288
"pushPlaywright${simplyVersion}${debianVer.name.capitalized()}Image",
9389
DockerPushImage::class.java
9490
) {
95-
group = "python"
91+
group = "python${simplyVersion}"
9692
dependsOn(buildPlaywright)
9793
images.addAll(tagsPlaywright)
9894
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, tagsPlaywright.last())
@@ -115,24 +111,28 @@ class DockerPlugin: Plugin<Project> {
115111
val dockerCreateCudaDockerfile = target.tasks.register("cudaDockerfile", CudaDockerfile::class.java) {
116112
destFile.set(buildFile("cuda/Dockerfile"))
117113
}
114+
val dockerCreateCudnnDockerfile = target.tasks.register("cudnnDockerfile", CudnnDockerfile::class.java) {
115+
destFile.set(buildFile("cudnn/Dockerfile"))
116+
}
118117
target.tasks.register("cudaVersions", CodaVersions::class.java)
119118

120119
for ((cudaMinorVer, cudaInfo) in target.CudaVersionsInfo().versions) {
121120
val simplyCudaMinorVer = cudaMinorVer.replace(".", "")
122121
val debCudaMinorVer = cudaMinorVer.replace(".", "-")
123-
for ((debianVer, cudaFullVer) in cudaInfo) {
122+
val cudaMajorVer = debCudaMinorVer.split("-")[0]
123+
for ((debianVer, cudaFullVer) in cudaInfo.debian) {
124124
for ((pyMinorVer, pyInfo) in target.PythonVersionsInfo().versions) {
125125
val simplePyMinorVer = pyMinorVer.replace(".", "")
126126
val pyFullVer = pyInfo[debianVer] ?: continue
127-
val tagsPoetry = listOf(
127+
val cudaPoetryTags = listOf(
128128
"${DOCKER_TAG}:$pyMinorVer-$debianVer-cuda$cudaMinorVer",
129129
"${DOCKER_TAG}:${pyFullVer}-$debianVer-cuda$cudaFullVer",
130130
)
131131
val buildCudaPoetry = target.tasks.register(
132132
"buildCuda${simplyCudaMinorVer}Poetry${simplePyMinorVer}${debianVer.name.capitalized()}Image",
133133
DockerBuildImage::class.java
134134
) {
135-
group = "cuda"
135+
group = "cuda${simplyCudaMinorVer}"
136136
inputDir.set(buildDir("cuda"))
137137
buildArgs.putAll(mapOf(
138138
"PYTHON_VERSION" to pyFullVer,
@@ -142,19 +142,19 @@ class DockerPlugin: Plugin<Project> {
142142
"CUDA_VERSION" to cudaFullVer,
143143
"CUDA_MINOR_VERSION" to debCudaMinorVer,
144144
))
145-
images.addAll(tagsPoetry)
145+
images.addAll(cudaPoetryTags)
146146
dependsOn(dockerCreateCudaDockerfile)
147147
dockerFile.set(dockerCreateCudaDockerfile.get().destFile)
148-
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, tagsPoetry.last())
148+
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, cudaPoetryTags.last())
149149
}
150150
val pushCudaPoetry = target.tasks.register(
151151
"pushCuda${simplyCudaMinorVer}Poetry${simplePyMinorVer}${debianVer.name.capitalized()}Image",
152152
DockerPushImage::class.java
153153
) {
154-
group = "cuda"
154+
group = "cuda${simplyCudaMinorVer}"
155155
dependsOn(buildCudaPoetry)
156-
images.addAll(tagsPoetry)
157-
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, tagsPoetry.last())
156+
images.addAll(cudaPoetryTags)
157+
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, cudaPoetryTags.last())
158158
}
159159

160160
val tagsPlaywright = listOf(
@@ -165,7 +165,7 @@ class DockerPlugin: Plugin<Project> {
165165
"buildCuda${simplyCudaMinorVer}Playwright${simplePyMinorVer}${debianVer.name.capitalized()}Image",
166166
DockerBuildImage::class.java
167167
) {
168-
group = "cuda"
168+
group = "cuda${simplyCudaMinorVer}"
169169
buildArgs.putAll(mapOf(
170170
"PYTHON_VERSION" to pyFullVer,
171171
"DEBIAN_VERSION" to "$debianVer",
@@ -184,7 +184,7 @@ class DockerPlugin: Plugin<Project> {
184184
"pushCuda${simplyCudaMinorVer}Playwright${simplePyMinorVer}${debianVer.name.capitalized()}Image",
185185
DockerPushImage::class.java
186186
) {
187-
group = "cuda"
187+
group = "cuda${simplyCudaMinorVer}"
188188
dependsOn(buildCudaPlaywright)
189189
images.addAll(tagsPlaywright)
190190
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, tagsPlaywright.last())
@@ -194,6 +194,83 @@ class DockerPlugin: Plugin<Project> {
194194
pushCudaPoetry.get().enabled = false
195195
pushCudaPlaywright.get().enabled = false
196196
}
197+
198+
for ((cudnnMajorVer, cudnnFullVer) in cudaInfo.cudnn) {
199+
val cudnnPoetryTags = listOf(
200+
"${DOCKER_TAG}:$pyMinorVer-$debianVer-cuda$cudaMinorVer-cudnn$cudnnMajorVer",
201+
"${DOCKER_TAG}:${pyFullVer}-$debianVer-cuda$cudaFullVer-cudnn$cudnnMajorVer",
202+
)
203+
val buildCudnnPoetry = target.tasks.register(
204+
"buildCuda${simplyCudaMinorVer}Cudnn${cudnnMajorVer}Poetry${simplePyMinorVer}${debianVer.name.capitalized()}Image",
205+
DockerBuildImage::class.java
206+
) {
207+
group = "cudnn${cudnnMajorVer}-${simplyCudaMinorVer}"
208+
inputDir.set(buildDir("cudnn"))
209+
buildArgs.putAll(mapOf(
210+
"PYTHON_VERSION" to pyFullVer,
211+
"DEBIAN_VERSION" to "$debianVer",
212+
"CUDA_VERSION" to cudaFullVer,
213+
"FLAVOR" to "",
214+
"CUDNN_VERSION" to cudnnFullVer,
215+
"CUDNN_MAJOR_VERSION" to cudnnMajorVer,
216+
"CUDA_MAJOR_VERSION" to cudaMajorVer,
217+
))
218+
images.addAll(cudnnPoetryTags)
219+
mustRunAfter(buildCudaPoetry)
220+
dependsOn(dockerCreateCudnnDockerfile)
221+
dockerFile.set(dockerCreateCudnnDockerfile.get().destFile)
222+
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, cudnnPoetryTags.last())
223+
}
224+
val pushCudnnPoetry = target.tasks.register(
225+
"pushCuda${simplyCudaMinorVer}Cudnn${cudnnMajorVer}Poetry${simplePyMinorVer}${debianVer.name.capitalized()}Image",
226+
DockerPushImage::class.java
227+
) {
228+
group = "cudnn${cudnnMajorVer}-${simplyCudaMinorVer}"
229+
dependsOn(buildCudnnPoetry)
230+
images.addAll(cudnnPoetryTags)
231+
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, cudnnPoetryTags.last())
232+
}
233+
234+
val cudnnPlaywrightTags = listOf(
235+
"${DOCKER_TAG}:${pyMinorVer}-${debianVer}-playwright-cuda$cudaMinorVer",
236+
"${DOCKER_TAG}:${pyFullVer}-${debianVer}-playwright-cuda$cudaFullVer",
237+
)
238+
val buildCudnnPlaywright = target.tasks.register(
239+
"buildCuda${simplyCudaMinorVer}Cudnn${cudnnMajorVer}Playwright${simplePyMinorVer}${debianVer.name.capitalized()}Image",
240+
DockerBuildImage::class.java
241+
) {
242+
group = "cudnn${cudnnMajorVer}-${simplyCudaMinorVer}"
243+
buildArgs.putAll(mapOf(
244+
"PYTHON_VERSION" to pyFullVer,
245+
"DEBIAN_VERSION" to "$debianVer",
246+
"CUDA_VERSION" to cudaFullVer,
247+
"FLAVOR" to "-playwright",
248+
"CUDNN_VERSION" to cudnnFullVer,
249+
"CUDNN_MAJOR_VERSION" to cudnnMajorVer,
250+
"CUDA_MAJOR_VERSION" to cudaMajorVer,
251+
))
252+
inputDir.set(buildDir("cudnn"))
253+
images.addAll(cudnnPlaywrightTags)
254+
mustRunAfter(buildCudaPlaywright)
255+
dependsOn(dockerCreateCudnnDockerfile)
256+
dockerFile.set(dockerCreateCudnnDockerfile.get().destFile)
257+
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, cudnnPoetryTags.last())
258+
}
259+
val pushCudnnPlaywright = target.tasks.register(
260+
"pushCuda${simplyCudaMinorVer}Cudnn${cudnnMajorVer}Playwright${simplePyMinorVer}${debianVer.name.capitalized()}Image",
261+
DockerPushImage::class.java
262+
) {
263+
group = "cuda${simplyCudaMinorVer}"
264+
dependsOn(buildCudnnPlaywright)
265+
images.addAll(cudnnPoetryTags)
266+
upToDateWhenTagExist(DOCKER_NAMESPACE, DOCKER_REPOSITORY, cudnnPoetryTags.last())
267+
}
268+
269+
if (target.DOCKER_TOKEN == null) {
270+
pushCudnnPoetry.get().enabled = false
271+
pushCudnnPlaywright.get().enabled = false
272+
}
273+
}
197274
}
198275
}
199276
}

buildSrc/src/main/kotlin/io/github/sgpublic/tasks/CodaVersions.kt

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@ import io.github.g00fy2.versioncompare.Version
44
import io.github.sgpublic.utils.DebianVersion
55
import io.github.sgpublic.utils.NetString
66
import io.github.sgpublic.utils.CudaVersionsInfo
7-
import io.github.sgpublic.utils.VersionsInfo
8-
import io.github.sgpublic.utils.choseByTimezone
97
import org.gradle.api.DefaultTask
108
import org.gradle.api.tasks.TaskAction
9+
import kotlin.collections.LinkedHashMap
10+
11+
data class CudaVersion(
12+
val debian: LinkedHashMap<DebianVersion, String> = LinkedHashMap(),
13+
val cudnn: LinkedHashMap<String, String> = LinkedHashMap(),
14+
)
1115

1216
open class CodaVersions: DefaultTask() {
1317
init {
@@ -16,34 +20,59 @@ open class CodaVersions: DefaultTask() {
1620

1721
@TaskAction
1822
fun execute() {
19-
val versions = LinkedHashMap<String, Map<DebianVersion, String>>()
23+
val versions = LinkedHashMap<String, CudaVersion>()
2024
for (debianVersion in DebianVersion.values()) {
2125
versionRss(debianVersion, versions)
2226
}
23-
project.CudaVersionsInfo(VersionsInfo(versions.toSortedMap()))
27+
project.CudaVersionsInfo(CudaVersionsInfo(versions.toSortedMap()))
2428
}
2529

26-
private fun versionRss(debian: DebianVersion, versions: LinkedHashMap<String, Map<DebianVersion, String>>) {
27-
val packageMatch = "Package: cuda-\\d+-\\d+".toRegex()
28-
val versionMatch = "Version: \\d+.\\d+.\\d+-1".toRegex()
30+
private fun versionRss(debian: DebianVersion, versions: LinkedHashMap<String, CudaVersion>) {
31+
val cudaMatch = "Package: cuda-\\d+-\\d+".toRegex()
32+
val cudnnMatch = "Package: cudnn\\d+-cuda-\\d+-\\d+".toRegex()
33+
val cudnnMajorMatch = "cudnn\\d+".toRegex()
34+
val versionMatch = "Version: (.*?)-1".toRegex()
2935
val versionsRss = NetString("https://developer.download.nvidia.com/compute/cuda/repos/debian${debian.numVer}/x86_64/Packages")
30-
.split("\n\n").filter { it.contains(packageMatch) }
31-
for (version in versionsRss) {
36+
.split("\n\n")
37+
val cudaVersionRss = versionsRss.filter { it.contains(cudaMatch) }
38+
val cudnnVersionRss = versionsRss.filter { it.contains(cudnnMatch) }
39+
for (version in cudaVersionRss) {
40+
// cuda
3241
val details = version.split("\n")
33-
val minerVersion = packageMatch.find(details[0])?.value?.let {
42+
val cudaMinerRaw = cudaMatch.find(details[0])?.value?.let {
3443
return@let it.subSequence(14, it.length).toString()
35-
}?.replace("-", ".") ?: continue
36-
val minerVersionInfo = LinkedHashMap<DebianVersion, String>(versions[minerVersion] ?: emptyMap())
37-
val fullVersion = Version(
44+
} ?: continue
45+
val cudaMiner = cudaMinerRaw.replace("-", ".")
46+
val cudaMinerInfo = versions[cudaMiner] ?: CudaVersion()
47+
val cudaFullVer = Version(
3848
versionMatch.find(details[1])?.value?.let {
3949
return@let it.subSequence(9, it.length - 2).toString()
4050
} ?: continue
4151
)
42-
val storedVersion = minerVersionInfo[debian]
43-
if (storedVersion == null || Version(storedVersion) < fullVersion) {
44-
minerVersionInfo[debian] = fullVersion.toString()
52+
val existCudaVer = cudaMinerInfo.debian[debian]
53+
if (existCudaVer == null || Version(existCudaVer) < cudaFullVer) {
54+
cudaMinerInfo.debian[debian] = cudaFullVer.toString()
55+
}
56+
57+
// cudnn
58+
val cudnnItems = cudnnVersionRss.filter { it.contains("cuda-${cudaMinerRaw}") }
59+
for (cudnnItem: String in cudnnItems) {
60+
val cudnnDetails = cudnnItem.split("\n")
61+
val cudnnMajor = cudnnMajorMatch.find(cudnnDetails[0])
62+
?.value?.substring(5)
63+
?: continue
64+
val cudnnFullVer = Version(
65+
versionMatch.find(cudnnDetails[1])?.value?.let {
66+
return@let it.subSequence(9, it.length - 2).toString()
67+
} ?: continue
68+
)
69+
val existCudnnVer = cudaMinerInfo.cudnn[cudnnMajor]
70+
if (existCudnnVer == null || Version(existCudnnVer) < cudnnFullVer) {
71+
cudaMinerInfo.cudnn[cudnnMajor] = cudnnFullVer.toString()
72+
}
4573
}
46-
versions[minerVersion] = minerVersionInfo
74+
75+
versions[cudaMiner] = cudaMinerInfo
4776
}
4877
}
4978

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package io.github.sgpublic.tasks
2+
3+
import com.bmuschko.gradle.docker.tasks.image.Dockerfile
4+
import io.github.sgpublic.DockerPlugin
5+
import io.github.sgpublic.utils.*
6+
import org.gradle.api.tasks.*
7+
8+
abstract class CudnnDockerfile: Dockerfile() {
9+
// @InputDirectory
10+
// fun getInputDir(): java.io.File {
11+
// return project.file("./src/main/cudnn")
12+
// }
13+
14+
@OutputDirectory
15+
fun getOutputDir(): java.io.File {
16+
return buildDir("cudnn").get().asFile
17+
}
18+
19+
@TaskAction
20+
override fun create() {
21+
this.arg("PYTHON_VERSION")
22+
this.arg("DEBIAN_VERSION")
23+
this.arg("CUDA_VERSION")
24+
this.arg("FLAVOR")
25+
this.from(From("${DockerPlugin.DOCKER_TAG}:\${PYTHON_VERSION}-\${DEBIAN_VERSION}\${FLAVOR}-cuda\${CUDA_VERSION}"))
26+
27+
this.arg("CUDNN_VERSION")
28+
this.environmentVariable("CUDNN_VERSION", "\${CUDNN_VERSION}")
29+
this.arg("CUDNN_MAJOR_VERSION")
30+
this.arg("CUDA_MAJOR_VERSION")
31+
runCommand(command(
32+
aptInstall(
33+
"cudnn\${CUDNN_MAJOR_VERSION}-cuda-\${CUDA_MAJOR_VERSION}"
34+
),
35+
))
36+
37+
super.create()
38+
}
39+
40+
override fun getGroup(): String {
41+
return "cudnn"
42+
}
43+
}

buildSrc/src/main/kotlin/io/github/sgpublic/tasks/PythonVersions.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ open class PythonVersions: DefaultTask() {
5353
checkedVersions[minorVer.toString()] = storedMinorVer
5454
}
5555

56-
project.PythonVersionsInfo(VersionsInfo(checkedVersions))
56+
project.PythonVersionsInfo(PyVersionsInfo(checkedVersions))
5757
}
5858

5959
override fun getGroup(): String {

0 commit comments

Comments
 (0)