检查Anaconda环境与PyTorch配置
在开始深度学习项目前,首先需要确认当前Anaconda环境中是否已正确安装Python和PyTorch相关组件。可通过以下命令查看版本信息:
- 查看Python版本:python --version
- 查看PyTorch版本:pip show torch 或 import torch; print(torch.__version__)
conda list -n pytorch
创建Python虚拟环境
使用Conda可以方便地创建独立的Python运行环境,避免依赖冲突。创建命令如下:
conda create -n <环境名字> python=<版本号>
进阶技巧:快速创建并安装依赖
若希望在创建环境时一并安装常用库(如numpy、matplotlib等),可在命令后追加包名:
conda create -n myenv python=3.9 numpy matplotlib
此外,为跳过每次确认输入“y”的步骤,可添加-y参数实现自动确认:
conda create -n myenv python=3.9 -y
conda create -n gemini_env python=3.11 -y
conda create -n gemini_env python=3.11 google-generativeai jupyter -y
错误环境的处理方式
若误创建了不正确的环境,可通过以下命令将其删除:
conda remove -n <环境名> --all
该操作将彻底移除指定环境及其所有安装的包。
conda env remove -n gemini_env
Python与PyTorch的安装
若系统中未安装合适版本的Python,可使用pip进行安装(以Python 3.10为例):
pip install python3.10.*
conda install pytorch torchvision torchaudio -c pytorch
PyTorch中的张量(Tensor)基础
Tensor是PyTorch中最核心的数据结构,类似于NumPy中的数组,但支持GPU加速和自动求导。常见创建方式包括:
import torch
t = torch.tensor([1.2, 3.4])
torch.tensor((2,2))
torch.ones_like()
torch.zeros_like()
torch.rand((2,2))
torch.rand([2,2])
torch.rand((2,2,))
torch.rand([2,2,])
张量属性(Tensor Attribution)
每个张量都具备若干关键属性,用于描述其数据特征:
- a.dtype:表示张量中元素的数据类型(如torch.float32)
- a.shape:返回张量各维度的大小,形式为元组
- a.device:指示张量存储位置(CPU或CUDA设备)
is_tensor()
torch.is_nonzero(a)
判断对象是否为PyTorch张量
可通过isinstance函数判断一个变量是否为torch.Tensor类型:
isinstance(t, torch.Tensor)
torch.numel()
numel() 与 len() 的关键区别
两者均用于获取张量的长度信息,但含义不同:
- len(t):仅返回第0维(第一维)的长度,即张量最外层的元素个数。
- t.numel():返回张量中所有元素的总数,等价于各维度大小的乘积。
例如,对于形状为(2, 3, 4)的张量:
- len(t) 输出 2
- t.numel() 输出 24
a=torch.zeros([5,5,])
torch.set_default_tensor_type(torch.DoubleTensor)
修改PyTorch默认张量类型
PyTorch在创建浮点型张量时,默认使用float32(torch.float)。可通过以下代码更改为double类型(float64):
import torch
# 设置默认张量类型为 DoubleTensor
torch.set_default_tensor_type(torch.DoubleTensor)
# 新建浮点张量将自动使用 float64
b = torch.tensor([1.2, 3.4])
print(b.dtype) # 输出: torch.float64
# 注意:整数张量不受影响
c = torch.tensor([1, 2])
print(c.dtype) # 输出: torch.int64
torch.arange()
张量拼接操作:torch.cat()
torch.cat() 是 concatenate 的缩写,用于沿指定维度连接两个或多个张量,类似于拼积木。
示例:设有两个矩阵 a(2×2)与 b(2×3),在列方向(dim=1)上进行拼接:
result = torch.cat((a, b), dim=1)
结果将是一个形状为 (2, 5) 的新张量。
若在行方向(dim=0)拼接,则扩展的是行数,即纵向堆叠。
torch.cat([a, b], dim=1)
三维张量的理解与拼接
三维张量常用于表示彩色图像或具有通道结构的数据,其形状通常表示为:
(Channel, Height, Width)
也可理解为“层数、高度、宽度”。
假设存在两个形状均为 (2, 3, 4) 的张量 t1 和 t2:
- 2 表示有 2 层(如颜色通道)
- 3 表示高度方向有 3 行
- 4 表示宽度方向有 4 列
t1 = torch.rand(2, 3, 4)
t2 = torch.rand(2, 3, 4)
print(f"原始形状: {t1.shape}")
不同维度下的拼接效果
1. dim=0:沿“厚度”方向拼接
即将 t2 叠加在 t1 上方,增加层数:
- 直观类比:两片面包叠成双层汉堡
- 数学变化:2 + 2 = 4
- 输出形状:(4, 3, 4)
torch.range()
torch.eye()
torch.full()
其他常用张量函数简介
除了拼接外,还有几个重要的张量操作函数:
- torch.chunk():将张量沿某一维度均分为若干块
- torch.gather():根据索引从张量中提取特定元素
- torch.reshape():改变张量的形状而不改变其数据
torch.cat()
区间表示说明:
- 左闭右开区间:包含起始值,不包含结束值(如[0, 5))
- 闭区间:同时包含起始与结束值(如[0, 5])
torch.range()

1. dim=0:拼接“厚度/层数”(变厚了)
当使用 dim=0 进行拼接时,表示在第 0 个维度上进行连接。可以理解为将张量 t2 放在 t1 的上方或下方,形成一个新的堆叠结构。
实际应用示例:假设你有一张 RGB 图像(包含3个颜色通道),想要为其添加一个透明度通道(Alpha),从而组合成 RGBA 格式的图像(共4个通道)。此时就可以通过拼接操作实现通道数的扩展。
直观理解:类似于将两个立方体沿垂直方向叠加,整体变得更“厚”。
数学变化:原第0维大小为 2 + 2 = 4。
结果形状:(4, 3, 4)
限制条件:要求其余维度(即高度和宽度)必须完全一致才能拼接。
c0 = torch.cat([t1, t2], dim=0)
print(f"dim=0 拼接后:\n{c0.shape}")
# 输出: torch.Size([4, 3, 4]) <-- 只有第0维发生变化
2. dim=1:拼接“高度/行数”(变高了)
选择 dim=1 表示在第二个维度上进行拼接。你可以想象成把 t2 放置在 t1 的正下方,就像把两张图片上下对接,形成一张更长的图像。
直观理解:图像上下合并,视觉上拉长。
数学变化:3 + 3 = 6。
结果形状:(2, 6, 4)
限制条件:第0维(厚度)与第2维(宽度)的尺寸必须相同,才能完成拼接。
c1 = torch.cat([t1, t2], dim=1)
print(f"dim=1 拼接后:\n{c1.shape}")
# 输出: torch.Size([2, 6, 4]) <-- 仅第1维发生改变
3. dim=2:拼接“宽度/列数”(变宽了)
采用 dim=2 拼接是在第三个维度上进行操作。相当于将 t2 放在 t1 的右侧,实现左右并排的效果。
直观理解:两张图横向拼接,形成一幅全景图。
数学变化:4 + 4 = 8。
结果形状:(2, 3, 8)
限制条件:第0维(厚度)和第1维(高度)需保持一致。
c2 = torch.cat([t1, t2], dim=2)
print(f"dim=2 拼接后:\n{c2.shape}")
# 输出: torch.Size([2, 3, 8]) <-- 仅第2维发生变化
torch.chunk(tensor,chunks,dim=0)
总结与类比:张量如同立体蛋糕
若将张量视作一个长方体形状的蛋糕,不同维度的拼接可类比为对蛋糕不同方向的扩展:
- dim=0:增加蛋糕的层数(如奶油层、夹心层叠加),底面积不变,整体变得更厚。
- dim=1:延长蛋糕的长度(前后方向拼接),横截面与宽度不变,蛋糕变得更长。
- dim=2:扩展蛋糕的宽度(左右方向拼接),切面形状和长度不变,蛋糕变得更宽。
核心记忆口诀:
指定哪个维度(dim),该维度的数值相加,其余维度保持不变。
相关函数说明
torch.chunk():用于将张量沿指定维度分割为若干等份块(chunks),常用于分批处理数据。
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor:按照给定索引从输入张量中收集元素,常用于特征提取或索引重构。
torch.reshape():仅改变张量的外形结构,不调整内部元素顺序,适用于维度变换但不改变数据内容的操作。