Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Makefile for gotch - PyTorch 2.10.0 Go bindings
# Architecture-specific configuration is loaded from arch/ directory

SHELL := /bin/bash

# Detect OS and architecture
OS := $(shell uname -s | tr '[:upper:]' '[:lower:]')
ARCH := $(shell uname -m)

# Normalize architecture names
ifeq ($(ARCH),x86_64)
ARCH := amd64
else ifeq ($(ARCH),aarch64)
ARCH := arm64
endif

# Allow override of architecture configuration
# Examples:
# make ARCH_CONFIG=linux-amd64-cuda build
# make ARCH_CONFIG=darwin-arm64 test
ARCH_CONFIG ?= $(OS)-$(ARCH)

# Include architecture-specific configuration
ARCH_FILE := arch/$(ARCH_CONFIG).mk
ifeq ($(wildcard $(ARCH_FILE)),)
$(error Architecture config file not found: $(ARCH_FILE). Available: $(wildcard arch/*.mk))
endif
include $(ARCH_FILE)

# Go test flags
TEST_FLAGS := -v
TEST_TIMEOUT := 5m

.PHONY: all build test test-nn test-ts clean help ffi-validate

# Default target
all: build

# Build all core packages
build:
@echo "Building gotch core packages..."
@go build -v . ./ts ./nn ./vision

# Run all tests
test: test-nn test-ts ffi-validate
@echo "All tests completed"

# Run nn package tests
# Running with -p 1 to force sequential execution (PyTorch 2.10.0 thread-local gradient state)
test-nn:
@echo "Running nn package tests..."
@go test $(TEST_FLAGS) -timeout $(TEST_TIMEOUT) -p 1 -parallel 1 ./nn

# Run ts package tests
test-ts:
@echo "Running ts package tests..."
@go test $(TEST_FLAGS) -timeout $(TEST_TIMEOUT) ./ts

# Run specific test in nn package
# Usage: make test-nn-specific TEST=TestInitTensor_Memcheck
test-nn-specific:
@echo "Running specific nn test: $(TEST)..."
@go test $(TEST_FLAGS) -timeout $(TEST_TIMEOUT) -run $(TEST) ./nn

# Run specific test in ts package
# Usage: make test-ts-specific TEST=TestTensor
test-ts-specific:
@echo "Running specific ts test: $(TEST)..."
@go test $(TEST_FLAGS) -timeout $(TEST_TIMEOUT) -run $(TEST) ./ts

# Run tests with coverage
test-nn-coverage:
@echo "Running nn tests with coverage..."
@go test -v -timeout $(TEST_TIMEOUT) -coverprofile=coverage-nn.out ./nn
@go tool cover -html=coverage-nn.out -o coverage-nn.html
@echo "Coverage report saved to coverage-nn.html"

test-ts-coverage:
@echo "Running ts tests with coverage..."
@go test -v -timeout $(TEST_TIMEOUT) -coverprofile=coverage-ts.out ./ts
@go tool cover -html=coverage-ts.out -o coverage-ts.html
@echo "Coverage report saved to coverage-ts.html"

# Clean build artifacts
clean:
@echo "Cleaning build artifacts..."
@go clean -cache -testcache
@rm -f coverage-*.out coverage-*.html

# Display build environment
env:
@echo "Build Environment:"
@echo " Platform: $(PLATFORM_DESC)"
@echo " Arch Config: $(ARCH_CONFIG) ($(ARCH_FILE))"
@echo " LIBTORCH_PATH: $(LIBTORCH_PATH)"
@echo " $(RUNTIME_LIB_VAR): $($(RUNTIME_LIB_VAR))"
@echo ""
@echo "CGO Flags:"
@echo " CGO_CFLAGS: $(CGO_CFLAGS)"
@echo " CGO_LDFLAGS: $(CGO_LDFLAGS)"
@echo " CGO_CXXFLAGS: $(CGO_CXXFLAGS)"
@echo ""
@echo "Go version:"
@go version
@echo ""
@echo "LibTorch version: 2.10.0"

# Check if MPS is available
check-mps:
@echo "Checking MPS availability..."
@go run -exec 'env DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH)' tools/check_device.go || echo "Create tools/check_device.go to test device availability"

# Validate FFI type conversions
ffi-validate:
@echo "Validating FFI type conversions..."
@go run tools/ffi-validation/main.go

# Help target
help:
@echo "Gotch Makefile - PyTorch 2.10.0 Go Bindings"
@echo ""
@echo "Current Configuration:"
@echo " Platform: $(PLATFORM_DESC)"
@echo " Arch Config: $(ARCH_CONFIG)"
@echo " LibTorch: $(LIBTORCH_PATH)"
@echo ""
@echo "Targets:"
@echo " make build - Build all core packages"
@echo " make test - Run all tests (nn + ts)"
@echo " make test-nn - Run nn package tests"
@echo " make test-ts - Run ts package tests"
@echo " make test-nn-specific - Run specific nn test (TEST=TestName)"
@echo " make test-ts-specific - Run specific ts test (TEST=TestName)"
@echo " make test-nn-coverage - Run nn tests with coverage report"
@echo " make test-ts-coverage - Run ts tests with coverage report"
@echo " make clean - Clean build artifacts and caches"
@echo " make env - Display build environment"
@echo " make check-mps - Check MPS device availability"
@echo " make ffi-validate - Validate FFI type conversions (C <-> Go)"
@echo " make help - Show this help message"
@echo ""
@echo "Architecture Configuration:"
@echo " Default: auto-detected (current: $(ARCH_CONFIG))"
@echo " Override: make ARCH_CONFIG=linux-amd64-cuda build"
@echo " Available configs: $(notdir $(basename $(wildcard arch/*.mk)))"
@echo ""
@echo "Examples:"
@echo " make test-nn"
@echo " make test-nn-specific TEST=TestInitTensor_Memcheck"
@echo " make ARCH_CONFIG=linux-amd64-cuda build"
@echo " LIBTORCH_PATH=/custom/path make build"
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@

## Dependencies

- **Libtorch** C++ v2.1.0 library of [Pytorch](https://pytorch.org/)
- **Libtorch** C++ v2.10.0 library of [Pytorch](https://pytorch.org/)
- Clang-17/Clang++-17 compilers

## Installation

- Default CUDA version is `11.8` if CUDA is available otherwise using CPU version.
- Default Pytorch C++ API version is `2.1.0`
- Default Pytorch C++ API version is `2.10.0`

**NOTE**: `libtorch` will be installed at **`/usr/local/lib`**

Expand Down Expand Up @@ -266,6 +266,20 @@ func main() {

- See [pkg.go.dev](https://pkg.go.dev/github.com/sugarme/gotch?tab=doc) for APIs detail.

For unit tests use:
- make test

## PyTorch 2.10.0 Upgrade Notes

This version includes critical fixes for PyTorch 2.10.0 compatibility:
- Fixed gradient state management (thread-local in PyTorch 2.10.0)
- Fixed FFI type conversion bugs (C.int → Go int)

To validate FFI conversions:
```bash
make ffi-validate
```

## License

`gotch` is Apache 2.0 licensed.
Expand Down
38 changes: 37 additions & 1 deletion device.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ type Device struct {
}

type Cuda Device
type Mps Device

var (
CPU Device = Device{Name: "CPU", Value: -1}
MPS Mps = Mps{Name: "MPS", Value: -2}
CUDA Cuda = Cuda{Name: "CUDA", Value: 0}
)

Expand All @@ -34,6 +36,14 @@ func NewCuda() Device {
return CudaBuilder(0)
}

// MPS methods:
// ============

// IsAvailable returns true if MPS (Metal Performance Shaders) support is available
func (m Mps) IsAvailable() bool {
return lib.AtcMpsIsAvailable()
}

// Cuda methods:
// =============

Expand Down Expand Up @@ -74,6 +84,8 @@ func (d Device) CInt() CInt {
switch {
case d.Name == "CPU":
return -1
case d.Name == "MPS":
return -2
case d.Name == "CUDA":
// TODO: create a function to retrieve cuda_index
var deviceIndex int = d.Value
Expand All @@ -87,7 +99,9 @@ func (d Device) CInt() CInt {
func (d Device) OfCInt(v CInt) Device {
switch {
case v == -1:
return Device{Name: "CPU", Value: 1}
return Device{Name: "CPU", Value: -1}
case v == -2:
return Device{Name: "MPS", Value: -2}
case v >= 0:
return CudaBuilder(uint(v))
default:
Expand Down Expand Up @@ -124,3 +138,25 @@ func CudaIfAvailable() Device {
return CPU
}
}

// MPSIfAvailable returns an MPS device if available, else CPU.
func MPSIfAvailable() Device {
switch {
case MPS.IsAvailable():
return Device{Name: "MPS", Value: -2}
default:
return CPU
}
}

// BestAvailableDevice returns the best available device (CUDA > MPS > CPU).
func BestAvailableDevice() Device {
switch {
case CUDA.IsAvailable():
return CudaBuilder(0)
case MPS.IsAvailable():
return Device{Name: "MPS", Value: -2}
default:
return CPU
}
}
3 changes: 2 additions & 1 deletion gen/gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ let excluded_functions =
; "sym_size"
; "sym_stride"
; "sym_storage_offset"
; "_sparse_semi_structured_addmm"
]

let no_tensor_options =
Expand Down Expand Up @@ -1441,7 +1442,7 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename


let () =
run ~yaml_filename:"gen/pytorch/Declarations-v2.1.0.yaml"
run ~yaml_filename:"gen/pytorch/Declarations-v2.10.0.yaml"
~cpp_filename:"libtch/torch_api_generated"
~ffi_filename:"libtch/c-generated.go"
~must_wrapper_filename:"ts/must-tensor-generated.go"
Expand Down
Loading