Yishi Lin

Causal Inference Study Notes (4): Trying Out Classic Methods: Weighting (Lalonde's Dataset)

Posted on 2019-08-18 in Causal Inference, Study Notes

This post uses Lalonde's dataset (NSW + PSID + CPS data) to try out different weighting methods.

The content below was converted from R Markdown. Git: yishilin14/causal_playground

The blog's md file was generated with render("./causality4_playaround_with_the_lalonde_dataset_weighting.Rmd", md_document(variant = "markdown_github")).

Loading the Data

Dataset

  • Paper: Dehejia R H, Wahba S. Causal effects in nonexperimental studies: Reevaluating the evaluation of training programs[J]. Journal of the American Statistical Association, 1999, 94(448): 1053-1062.
  • Download: http://users.nber.org/~rdehejia/nswdata2.html
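# Packages used throughout this post
library(data.table)  # fread, data.table
library(survey)      # svydesign, svyglm
library(cobalt)      # bal.tab, get.w
library(WeightIt)    # weightit
library(knitr)       # kable

# Load the datasets
col.names <- c('treat', 'age', 'educ', 'black', 'hispan', 'married', 'nodegree', 're74', 're75', 're78')
dir <- './dataset/nswdata/'

# Experimental data: NSW treated + NSW control
nsw_data_exp <- rbind(
  fread(paste0(dir, 'nswre74_treated.txt'), col.names = col.names),
  fread(paste0(dir, 'nswre74_control.txt'), col.names = col.names)
)
# Observational data: NSW treated + CPS-3 comparison group
nsw_data_obs <- rbind(
  fread(paste0(dir, 'nswre74_treated.txt'), col.names = col.names),
  fread(paste0(dir, 'cps3_controls.txt'), col.names = col.names)
)

# Treatment-assignment formula shared by every weighting method below
treat.fml <- treat ~ age + educ + black + hispan + married + nodegree + re74 + re75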

Define a function that computes the ATT; it is reused throughout the post.

estimate.causal.effect <- function(data.name, method.name, unbalance.cnt = 0, data, weights) {
  # Weighted design without clustering (ids = ~1)
  d.w <- svydesign(ids = ~1, weights = weights, data = data)
  # Weighted regression of the outcome on treatment; the coefficient on treat is the effect estimate
  fit <- svyglm(re78 ~ treat, design = d.w)
  att <- round(coef(fit)["treat"])
  # 95% confidence interval for the treat coefficient, formatted as "(lower, upper)"
  conf <- paste0("(", paste0(round(confint(fit, "treat", 0.95)), collapse = ", "), ")")
  tbl <- data.table(
    `Data name` = data.name,
    `Method name` = method.name,
    `# Unbalance Var.` = unbalance.cnt,
    `ATT` = att,
    `95% Conf. Int.` = conf
  )
  return(tbl)
}
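For intuition about what this function returns: when the outcome is regressed on a single binary treatment indicator, the weighted least-squares coefficient on treat equals the difference between the weighted outcome means of the two groups. The sketch below checks this with base R's lm() on made-up data. The toy values and the helper name weighted.mean.diff are illustrative assumptions, not part of the original analysis, and the sketch does not reproduce the design-based standard errors that svyglm provides.

```r
# Toy data (made up for illustration): binary treatment, outcome, and weights
toy <- data.frame(
  treat = c(1, 1, 1, 0, 0, 0, 0),
  y     = c(10, 12, 9, 7, 8, 6, 11),
  w     = c(1, 1, 1, 0.5, 2, 1.5, 0.2)
)

# Hypothetical helper: weighted mean of y among treated minus weighted mean among controls
weighted.mean.diff <- function(y, treat, w) {
  weighted.mean(y[treat == 1], w[treat == 1]) -
    weighted.mean(y[treat == 0], w[treat == 0])
}

diff.means <- weighted.mean.diff(toy$y, toy$treat, toy$w)

# The same quantity, obtained as the coefficient on treat in a weighted regression
wls.coef <- unname(coef(lm(y ~ treat, data = toy, weights = w))["treat"])

print(c(diff.means = diff.means, wls.coef = wls.coef))
```

So every weighting method below changes only the weights passed in; the estimator itself stays a weighted difference in mean re78.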

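First, check whether the two groups in the experimental data are comparable. Oddly, the treated and control groups are not very close in age (age) or education (educ and nodegree). Comparing the two groups gives an ATT of 1794 US dollars, which the rest of this post treats as the ground truth for the ATT.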

t1 <- bal.tab(treat.fml, data = nsw_data_exp, estimand = "ATT", m.threshold = .1)$Balance
kable(t1[, c("Type", "M.0.Un", "M.1.Un","Diff.Un", "M.Threshold.Un")])

| | Type | M.0.Un | M.1.Un | Diff.Un | M.Threshold.Un |
|---|---|---|---|---|---|
| age | Contin. | 25.0538462 | 25.8162162 | 0.1065504 | Not Balanced, >0.1 |
| educ | Contin. | 10.0884615 | 10.3459459 | 0.1280603 | Not Balanced, >0.1 |
| black | Binary | 0.8269231 | 0.8432432 | 0.0163202 | Balanced, <0.1 |
| hispan | Binary | 0.1076923 | 0.0594595 | -0.0482328 | Balanced, <0.1 |
| married | Binary | 0.1538462 | 0.1891892 | 0.0353430 | Balanced, <0.1 |
| nodegree | Binary | 0.8346154 | 0.7081081 | -0.1265073 | Not Balanced, >0.1 |
| re74 | Contin. | 2107.0266585 | 2095.5736886 | -0.0023437 | Balanced, <0.1 |
| re75 | Contin. | 1266.9090025 | 1532.0553138 | 0.0823627 | Balanced, <0.1 |

t2 <- estimate.causal.effect("Experimental", "Ground Truth", 3, nsw_data_exp, rep(1, nrow(nsw_data_exp)))
causal.effect.results.all <- t2
kable(t2)

| Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
|---|---|---|---|---|
| Experimental | Ground Truth | 3 | 1794 | (481, 3108) |
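The Diff.Un column can be reproduced by hand. The sketch below assumes cobalt's defaults for the ATT: for continuous covariates, the difference in means is divided by the standard deviation of the treated group, while for binary covariates the raw difference in proportions is reported. The helper name balance.diff and the toy data are made up for illustration.

```r
# Hypothetical helper: balance statistic for one covariate under the ATT
balance.diff <- function(x, treat) {
  m1 <- mean(x[treat == 1])
  m0 <- mean(x[treat == 0])
  is.binary <- all(x %in% c(0, 1))
  if (is.binary) {
    m1 - m0                        # raw difference in proportions
  } else {
    (m1 - m0) / sd(x[treat == 1])  # standardized by the treated-group SD
  }
}

# Toy data (made up for illustration)
treat <- c(1, 1, 1, 1, 0, 0, 0, 0)
age   <- c(20, 24, 28, 32, 22, 26, 30, 38)
black <- c(1, 1, 1, 0, 1, 0, 0, 0)

d.age <- balance.diff(age, treat)
d.black <- balance.diff(black, treat)
print(c(age = d.age, black = d.black))

# A covariate is flagged "Not Balanced" when the absolute value exceeds the threshold
print(abs(c(age = d.age, black = d.black)) > 0.1)
```

This also explains why the sign matters little for the balance verdict: only the absolute value is compared against the 0.1 threshold.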

Next, look at the non-experimental data. The two groups differ a lot, so they cannot be compared directly.

t1 <- bal.tab(treat.fml, data = nsw_data_obs, estimand = "ATT", m.threshold = .1)$Balance
kable(t1[, c("Type", "M.0.Un", "M.1.Un","Diff.Un", "M.Threshold.Un")])

| | Type | M.0.Un | M.1.Un | Diff.Un | M.Threshold.Un |
|---|---|---|---|---|---|
| age | Contin. | 28.0303030 | 25.8162162 | -0.3094453 | Not Balanced, >0.1 |
| educ | Contin. | 10.2354312 | 10.3459459 | 0.0549647 | Balanced, <0.1 |
| black | Binary | 0.2027972 | 0.8432432 | 0.6404460 | Not Balanced, >0.1 |
| hispan | Binary | 0.1421911 | 0.0594595 | -0.0827317 | Balanced, <0.1 |
| married | Binary | 0.5128205 | 0.1891892 | -0.3236313 | Not Balanced, >0.1 |
| nodegree | Binary | 0.5967366 | 0.7081081 | 0.1113715 | Not Balanced, >0.1 |
| re74 | Contin. | 5619.2365064 | 2095.5736886 | -0.7210838 | Not Balanced, >0.1 |
| re75 | Contin. | 2466.4844431 | 1532.0553138 | -0.2902629 | Not Balanced, >0.1 |

unbalance.cnt <- sum(t1$M.Threshold.Un=="Not Balanced, >0.1")
t2 <- estimate.causal.effect("Observational", "Do nothing", unbalance.cnt,
                             nsw_data_obs, rep(1, nrow(nsw_data_obs)))
causal.effect.results.all <- rbind(causal.effect.results.all, t2)
kable(t2)

| Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
|---|---|---|---|---|
| Observational | Do nothing | 6 | -635 | (-1960, 690) |

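Causal Inference

In this part, we use the WeightIt package to try out different weighting methods.

WeightIt-PS

The first method is plain inverse propensity score weighting. The age covariate cannot be balanced very well (its adjusted difference, 0.119, is still above the 0.1 threshold); everything else is fine. The estimated ATT is 1214 US dollars.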

W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "ps")
t1 <- bal.tab(W.out, m.threshold = .1)$Balance
kable(t1[, c("Type", "M.0.Adj", "M.1.Adj","Diff.Adj", "M.Threshold")])

| | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |
|---|---|---|---|---|---|
| prop.score | Distance | 0.5819524 | 0.5774355 | -0.0205046 | |
| age | Contin. | 24.9658450 | 25.8162162 | 0.1188496 | Not Balanced, >0.1 |
| educ | Contin. | 10.4030803 | 10.3459459 | -0.0284159 | Balanced, <0.1 |
| black | Binary | 0.8454795 | 0.8432432 | -0.0022363 | Balanced, <0.1 |
| hispan | Binary | 0.0592923 | 0.0594595 | 0.0001672 | Balanced, <0.1 |
| married | Binary | 0.1705802 | 0.1891892 | 0.0186090 | Balanced, <0.1 |
| nodegree | Binary | 0.6896872 | 0.7081081 | 0.0184209 | Balanced, <0.1 |
| re74 | Contin. | 2106.0448305 | 2095.5736886 | -0.0021428 | Balanced, <0.1 |
| re75 | Contin. | 1496.5412337 | 1532.0553138 | 0.0110318 | Balanced, <0.1 |

unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
t2 <- estimate.causal.effect("Observational", "PS", unbalance.cnt, nsw_data_obs, get.w(W.out))
causal.effect.results.all <- rbind(causal.effect.results.all, t2)
kable(t2)

| Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
|---|---|---|---|---|
| Observational | PS | 1 | 1214 | (-402, 2830) |
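To make the weighting step concrete, here is a minimal base-R sketch of ATT propensity score weights, assuming a logistic regression propensity model: treated units keep weight 1, and each control unit is weighted by its odds of treatment, ps / (1 - ps). The data are simulated and the variable names are illustrative; this is not the WeightIt internals, only the textbook formula.

```r
set.seed(1)
n <- 500
x <- rnorm(n)
# Simulated treatment: units with larger x are more likely to be treated
treat <- rbinom(n, 1, plogis(-0.5 + x))

# Propensity score from a logistic regression of treatment on the covariate
ps <- fitted(glm(treat ~ x, family = binomial()))

# ATT weights: treated units keep weight 1, controls get the odds ps / (1 - ps)
w <- ifelse(treat == 1, 1, ps / (1 - ps))

# Covariate means: treated, raw controls, reweighted controls
m.treated <- mean(x[treat == 1])
m.control.raw <- mean(x[treat == 0])
m.control.w <- weighted.mean(x[treat == 0], w[treat == 0])
print(round(c(treated = m.treated,
              control.raw = m.control.raw,
              control.weighted = m.control.w), 3))
```

The reweighted control mean moves toward the treated mean but is not forced to match it exactly, which is consistent with age staying slightly above the threshold in the table above.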

WeightIt-GBM

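The second method is GBM. Two covariates (re74 and re75) are not well balanced, and the estimated ATT (461) is far from the ground truth (1794). My guess is that the dataset is too small for GBM to show its strength, or that the default parameters lead to overfitting or underfitting; I won't dig into it here.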

W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "gbm")
## Warning: No stop.method was provided. Using "es.mean".

## Warning: Some extreme weights were generated. Examine them with summary()
## and maybe trim them with trim().

t1 <- bal.tab(W.out, m.threshold = .1)$Balance
kable(t1[, c("Type", "M.0.Adj", "M.1.Adj","Diff.Adj", "M.Threshold")])

| | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |
|---|---|---|---|---|---|
| prop.score | Distance | 0.6016014 | 0.7355983 | 0.6472613 | |
| age | Contin. | 25.4715156 | 25.8162162 | 0.0481761 | Balanced, <0.1 |
| educ | Contin. | 10.4469803 | 10.3459459 | -0.0502496 | Balanced, <0.1 |
| black | Binary | 0.8287765 | 0.8432432 | 0.0144667 | Balanced, <0.1 |
| hispan | Binary | 0.0447290 | 0.0594595 | 0.0147304 | Balanced, <0.1 |
| married | Binary | 0.1888474 | 0.1891892 | 0.0003418 | Balanced, <0.1 |
| nodegree | Binary | 0.6638038 | 0.7081081 | 0.0443043 | Balanced, <0.1 |
| re74 | Contin. | 1510.2575345 | 2095.5736886 | 0.1197793 | Not Balanced, >0.1 |
| re75 | Contin. | 1074.0766359 | 1532.0553138 | 0.1422625 | Not Balanced, >0.1 |

unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
t2 <- estimate.causal.effect("Observational", "GBM", unbalance.cnt, nsw_data_obs, get.w(W.out))
causal.effect.results.all <- rbind(causal.effect.results.all, t2)
kable(t2)

| Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
|---|---|---|---|---|
| Observational | GBM | 2 | 461 | (-1422, 2344) |
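The warning about extreme weights is worth a quick look. One common diagnostic is Kish's effective sample size, (sum of weights)^2 / (sum of squared weights): when a few units carry most of the weight, it drops far below the nominal sample size and the estimate becomes noisy. The sketch below uses made-up weights, and the helper cap.weights is a hypothetical, simplified stand-in for trimming (it just caps weights at a chosen quantile); I have not checked it against what trim() does in detail.

```r
# Kish's effective sample size: (sum of weights)^2 / (sum of squared weights)
ess <- function(w) sum(w)^2 / sum(w^2)

# Hypothetical helper: cap weights at a given upper quantile
cap.weights <- function(w, at = 0.9) pmin(w, quantile(w, at, names = FALSE))

# Toy control-group weights (made up): one unit dominates
w <- c(0.1, 0.2, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1, 0.2, 8)

print(ess(rep(1, 10)))      # equal weights: ESS equals the sample size
print(ess(w))               # one extreme weight: ESS collapses
print(ess(cap.weights(w)))  # capping the extreme weight recovers some ESS
```

Trimming trades variance for bias: capped weights no longer balance the covariates as well, so balance should be rechecked afterwards.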

WeightIt-CBPS

With CBPS, every covariate is balanced, and the ATT estimate (1280) is decent too.

W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "cbps")
t1 <- bal.tab(W.out, m.threshold = .1)$Balance
kable(t1[, c("Type", "M.0.Adj", "M.1.Adj","Diff.Adj", "M.Threshold")])

| | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |
|---|---|---|---|---|---|
| prop.score | Distance | 0.5769492 | 0.5756980 | -0.0056535 | |
| age | Contin. | 25.8533261 | 25.8162162 | -0.0051865 | Balanced, <0.1 |
| educ | Contin. | 10.3493208 | 10.3459459 | -0.0016785 | Balanced, <0.1 |
| black | Binary | 0.8413337 | 0.8432432 | 0.0019096 | Balanced, <0.1 |
| hispan | Binary | 0.0596662 | 0.0594595 | -0.0002068 | Balanced, <0.1 |
| married | Binary | 0.1920629 | 0.1891892 | -0.0028737 | Balanced, <0.1 |
| nodegree | Binary | 0.7039155 | 0.7081081 | 0.0041926 | Balanced, <0.1 |
| re74 | Contin. | 2133.8214867 | 2095.5736886 | -0.0078270 | Balanced, <0.1 |
| re75 | Contin. | 1512.5767534 | 1532.0553138 | 0.0060507 | Balanced, <0.1 |

unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
t2 <- estimate.causal.effect("Observational", "CBPS", unbalance.cnt, nsw_data_obs, get.w(W.out))
causal.effect.results.all <- rbind(causal.effect.results.all, t2)
kable(t2)

| Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
|---|---|---|---|---|
| Observational | CBPS | 0 | 1280 | (-329, 2890) |

WeightIt-EBAL

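With EBAL (entropy balancing), every covariate is balanced, and to an astonishing degree: the adjusted differences are on the order of 1e-6. The ATT estimate (1273) is decent too.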

W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "ebal")
t1 <- bal.tab(W.out, m.threshold = .1)$Balance
kable(t1[, c("Type", "M.0.Adj", "M.1.Adj","Diff.Adj", "M.Threshold")])

| | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |
|---|---|---|---|---|---|
| age | Contin. | 25.8162278 | 25.8162162 | -1.6e-06 | Balanced, <0.1 |
| educ | Contin. | 10.3459478 | 10.3459459 | -9.0e-07 | Balanced, <0.1 |
| black | Binary | 0.8432400 | 0.8432432 | 3.3e-06 | Balanced, <0.1 |
| hispan | Binary | 0.0594592 | 0.0594595 | 2.0e-07 | Balanced, <0.1 |
| married | Binary | 0.1891903 | 0.1891892 | -1.1e-06 | Balanced, <0.1 |
| nodegree | Binary | 0.7081077 | 0.7081081 | 4.0e-07 | Balanced, <0.1 |
| re74 | Contin. | 2095.5851047 | 2095.5736886 | -2.3e-06 | Balanced, <0.1 |
| re75 | Contin. | 1532.0583946 | 1532.0553138 | -1.0e-06 | Balanced, <0.1 |

unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
t2 <- estimate.causal.effect("Observational", "EBAL", unbalance.cnt, nsw_data_obs, get.w(W.out))
causal.effect.results.all <- rbind(causal.effect.results.all, t2)
kable(t2)

| Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
|---|---|---|---|---|
| Observational | EBAL | 0 | 1273 | (-344, 2890) |
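The near-perfect balance is less surprising once you see that entropy balancing imposes it as a constraint: the control weights are chosen so that the reweighted control means equal the treated means, while staying as close to uniform as possible. Below is a rough sketch of that idea under simplifying assumptions (only first moments are balanced, base weights are uniform, data are simulated, and the target means are made up). It solves the dual problem with base R's optim() and is not WeightIt's implementation.

```r
set.seed(42)
# Simulated control covariates and target (treated-group) means; both are made up
X.control <- cbind(x1 = rnorm(200), x2 = rnorm(200))
target <- c(x1 = 0.3, x2 = -0.2)

# Center the control covariates at the target means
Z <- sweep(X.control, 2, target)

# Dual objective: log-sum-exp of Z %*% lambda. Its gradient is the weighted
# mean of Z, so it is zero exactly when the reweighted control means hit the target.
dual <- function(lambda) log(sum(exp(Z %*% lambda)))
dual.grad <- function(lambda) {
  w <- exp(Z %*% lambda)
  as.vector(crossprod(Z, w / sum(w)))
}

opt <- optim(c(0, 0), dual, dual.grad, method = "BFGS",
             control = list(reltol = 1e-12))

# Weights are proportional to exp(Z %*% lambda), normalized to sum to 1
w <- as.vector(exp(Z %*% opt$par))
w <- w / sum(w)

balanced.means <- colSums(X.control * w)
print(rbind(target = target, reweighted = balanced.means))
```

The remaining differences come only from the optimizer's tolerance, which matches the tiny values in the Diff.Adj column above.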

WeightIt-EBCW

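With EBCW, every covariate is balanced, again to an astonishing degree: the adjusted differences print as 0. The ATT estimate (1273) is identical to the EBAL result.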

W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "ebcw")
t1 <- bal.tab(W.out, m.threshold = .1)$Balance
kable(t1[, c("Type", "M.0.Adj", "M.1.Adj","Diff.Adj", "M.Threshold")])

| | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |
|---|---|---|---|---|---|
| age | Contin. | 25.8162162 | 25.8162162 | 0 | Balanced, <0.1 |
| educ | Contin. | 10.3459459 | 10.3459459 | 0 | Balanced, <0.1 |
| black | Binary | 0.8432432 | 0.8432432 | 0 | Balanced, <0.1 |
| hispan | Binary | 0.0594595 | 0.0594595 | 0 | Balanced, <0.1 |
| married | Binary | 0.1891892 | 0.1891892 | 0 | Balanced, <0.1 |
| nodegree | Binary | 0.7081081 | 0.7081081 | 0 | Balanced, <0.1 |
| re74 | Contin. | 2095.5736938 | 2095.5736886 | 0 | Balanced, <0.1 |
| re75 | Contin. | 1532.0553150 | 1532.0553138 | 0 | Balanced, <0.1 |

unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
t2 <- estimate.causal.effect("Observational", "EBCW", unbalance.cnt, nsw_data_obs, get.w(W.out))
causal.effect.results.all <- rbind(causal.effect.results.all, t2)
kable(t2)

| Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
|---|---|---|---|---|
| Observational | EBCW | 0 | 1273 | (-344, 2890) |

WeightIt-OptWeight

With optweight, every covariate is balanced, and the balance is essentially perfect. The ATT estimate (1204) is decent too.

W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "optweight")
t1 <- bal.tab(W.out, m.threshold = .1)$Balance
kable(t1[, c("Type", "M.0.Adj", "M.1.Adj","Diff.Adj", "M.Threshold")])

| | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |
|---|---|---|---|---|---|
| age | Contin. | 25.8162166 | 25.8162162 | 0e+00 | Balanced, <0.1 |
| educ | Contin. | 10.3459459 | 10.3459459 | 0e+00 | Balanced, <0.1 |
| black | Binary | 0.8432432 | 0.8432432 | 0e+00 | Balanced, <0.1 |
| hispan | Binary | 0.0594595 | 0.0594595 | 0e+00 | Balanced, <0.1 |
| married | Binary | 0.1891892 | 0.1891892 | 0e+00 | Balanced, <0.1 |
| nodegree | Binary | 0.7081081 | 0.7081081 | 0e+00 | Balanced, <0.1 |
| re74 | Contin. | 2095.5742123 | 2095.5736886 | -1e-07 | Balanced, <0.1 |
| re75 | Contin. | 1532.0554281 | 1532.0553138 | 0e+00 | Balanced, <0.1 |

unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
t2 <- estimate.causal.effect("Observational", "OptWeight", unbalance.cnt, nsw_data_obs, get.w(W.out))
causal.effect.results.all <- rbind(causal.effect.results.all, t2)
kable(t2)

| Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
|---|---|---|---|---|
| Observational | OptWeight | 0 | 1204 | (-399, 2808) |

Summary of Results

To wrap up, here are all the results in one table.

kable(causal.effect.results.all)

| Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
|---|---|---|---|---|
| Experimental | Ground Truth | 3 | 1794 | (481, 3108) |
| Observational | Do nothing | 6 | -635 | (-1960, 690) |
| Observational | PS | 1 | 1214 | (-402, 2830) |
| Observational | GBM | 2 | 461 | (-1422, 2344) |
| Observational | CBPS | 0 | 1280 | (-329, 2890) |
| Observational | EBAL | 0 | 1273 | (-344, 2890) |
| Observational | EBCW | 0 | 1273 | (-344, 2890) |
| Observational | OptWeight | 0 | 1204 | (-399, 2808) |
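To compare the methods at a glance, one can compute how far each observational estimate falls from the experimental benchmark. The snippet below only redoes arithmetic on the point estimates copied from the table above; it says nothing about statistical significance, since the confidence intervals are wide.

```r
# ATT point estimates copied from the summary table above
att <- c(`Do nothing` = -635, PS = 1214, GBM = 461, CBPS = 1280,
         EBAL = 1273, EBCW = 1273, OptWeight = 1204)
benchmark <- 1794  # experimental estimate, treated as ground truth in this post

# Signed gap to the benchmark, then methods ordered by absolute gap
gap <- att - benchmark
print(gap)
print(sort(abs(gap)))
```

All weighting methods except GBM land within roughly 500 to 600 dollars of the benchmark, and all of them underestimate it.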